From b66ac6c0896350d105d71bf1960eece62ebb0c3c Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 10 Jun 2024 11:19:01 -0700 Subject: [PATCH] Fix for working with byte arrays (#958) Fix for conversion of bytes to a [n]byte array --- common/types/bytes.go | 12 +++++++- common/types/bytes_test.go | 18 ++++++++++++ common/types/provider.go | 8 +++++ common/types/provider_test.go | 2 ++ ext/native.go | 3 ++ ext/native_test.go | 55 ++++++++++++++++++++++++++--------- 6 files changed, 83 insertions(+), 15 deletions(-) diff --git a/common/types/bytes.go b/common/types/bytes.go index 5838755f..7e813e29 100644 --- a/common/types/bytes.go +++ b/common/types/bytes.go @@ -58,7 +58,17 @@ func (b Bytes) Compare(other ref.Val) ref.Val { // ConvertToNative implements the ref.Val interface method. func (b Bytes) ConvertToNative(typeDesc reflect.Type) (any, error) { switch typeDesc.Kind() { - case reflect.Array, reflect.Slice: + case reflect.Array: + if len(b) != typeDesc.Len() { + return nil, fmt.Errorf("[%d]byte not assignable to [%d]byte array", len(b), typeDesc.Len()) + } + refArrPtr := reflect.New(reflect.ArrayOf(len(b), typeDesc.Elem())) + refArr := refArrPtr.Elem() + for i, byt := range b { + refArr.Index(i).Set(reflect.ValueOf(byt).Convert(typeDesc.Elem())) + } + return refArr.Interface(), nil + case reflect.Slice: return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil case reflect.Ptr: switch typeDesc { diff --git a/common/types/bytes_test.go b/common/types/bytes_test.go index 4c94693f..498a168f 100644 --- a/common/types/bytes_test.go +++ b/common/types/bytes_test.go @@ -17,6 +17,7 @@ package types import ( "bytes" "reflect" + "strings" "testing" "google.golang.org/protobuf/proto" @@ -71,6 +72,23 @@ func TestBytesConvertToNative_ByteSlice(t *testing.T) { } } +func TestBytesConvertToNative_ByteArray(t *testing.T) { + val, err := Bytes("123").ConvertToNative(reflect.TypeOf([3]byte{})) + if err != nil { + t.Error("Got unexpected value, wanted []byte{49, 50, 51}", err, val) + } + if val.([3]byte) != [3]byte{49, 50, 51} { + t.Errorf("Got %v, wanted [3]byte{49, 50, 51}", val) + } +} + +func TestBytesConvertToNative_ByteArrayError(t *testing.T) { + _, err := Bytes("123").ConvertToNative(reflect.TypeOf([1]byte{})) + if !strings.Contains(err.Error(), "[3]byte not assignable to [1]byte") { + t.Errorf("Got unexpected error %v, wanted not assignable error", err) + } +} + func TestBytesConvertToNative_Error(t *testing.T) { val, err := Bytes("123").ConvertToNative(reflect.TypeOf("")) if err == nil { diff --git a/common/types/provider.go b/common/types/provider.go index c5ff05fd..936a4e28 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -585,6 +585,14 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) { refKind := refValue.Kind() switch refKind { case reflect.Array, reflect.Slice: + if refValue.Type().Elem() == reflect.TypeOf(byte(0)) { + if refValue.CanAddr() { + return Bytes(refValue.Bytes()), true + } + tmp := reflect.New(refValue.Type()) + tmp.Elem().Set(refValue) + return Bytes(tmp.Elem().Bytes()), true + } return NewDynamicList(a, v), true case reflect.Map: return NewDynamicMap(a, v), true diff --git a/common/types/provider_test.go b/common/types/provider_test.go index efe1244a..a2b2026a 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -728,6 +728,8 @@ func TestNativeToValue_Primitive(t *testing.T) { expectNativeToValue(t, float64(-5.5), Double(-5.5)) expectNativeToValue(t, "hello", String("hello")) expectNativeToValue(t, []byte("world"), Bytes("world")) + expectNativeToValue(t, [4]byte{1, 2, 3, 4}, Bytes([]byte{1, 2, 3, 4})) + expectNativeToValue(t, &[4]byte{1, 2, 3, 4}, Bytes([]byte{1, 2, 3, 4})) expectNativeToValue(t, time.Duration(500), Duration{Duration: time.Duration(500)}) expectNativeToValue(t, time.Unix(12345, 0), Timestamp{Time: time.Unix(12345, 0)}) expectNativeToValue(t, dpb.New(time.Duration(500)), Duration{Duration: time.Duration(500)}) diff --git a/ext/native.go b/ext/native.go index 0e4bd305..75dff5a0 100644 --- a/ext/native.go +++ b/ext/native.go @@ -330,6 +330,9 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val { case []byte: return tp.baseAdapter.NativeToValue(val) default: + if refVal.Type().Elem() == reflect.TypeOf(byte(0)) { + return tp.baseAdapter.NativeToValue(val) + } return types.NewDynamicList(tp, val) } case reflect.Map: diff --git a/ext/native_test.go b/ext/native_test.go index a44bdc04..55e5aa04 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -201,6 +201,8 @@ func TestNativeTypes(t *testing.T) { {expr: `[TestAllTypes{BoolVal: true}, TestAllTypes{BoolVal: false}].exists(t, t.BoolVal == true)`}, {expr: `[TestAllTypes{CustomName: 'Alice'}, TestAllTypes{CustomName: 'Bob'}].exists(t, t.CustomName == 'Alice')`}, {expr: `[TestAllTypes{custom_name: 'Alice'}, TestAllTypes{custom_name: 'Bob'}].exists(t, t.custom_name == 'Alice')`, envOpts: []any{ParseStructTags(true)}}, + {expr: `TestAllTypes{BytesArrayVal: b'1234'}.BytesArrayVal != b'123'`}, + {expr: `TestAllTypes{BytesArrayVal: b'1234'}.BytesArrayVal == b'1234'`}, { expr: `tests.all(t, t.Int32Val > 17)`, in: map[string]any{ @@ -577,30 +579,51 @@ func TestNativeTypesConvertToNative(t *testing.T) { env := testNativeEnv(t, NativeTypes(reflect.TypeOf(TestNestedType{}))) adapter := env.CELTypeAdapter() conversions := []struct { - in any - out any - err string + in any + inType *cel.Type + out any + err string }{ { - in: &TestAllTypes{BoolVal: true}, - out: &TestAllTypes{BoolVal: true}, + in: &TestAllTypes{BoolVal: true}, + inType: cel.ObjectType("ext.TestAllTypes"), + out: &TestAllTypes{BoolVal: true}, }, { - in: TestAllTypes{BoolVal: true}, - out: &TestAllTypes{BoolVal: true}, + in: TestAllTypes{BoolVal: true}, + inType: cel.ObjectType("ext.TestAllTypes"), + out: &TestAllTypes{BoolVal: true}, }, { - in: &TestAllTypes{BoolVal: true}, - out: TestAllTypes{BoolVal: true}, + in: &TestAllTypes{BoolVal: true}, + inType: cel.ObjectType("ext.TestAllTypes"), + out: TestAllTypes{BoolVal: true}, }, { - in: nil, - out: types.NullValue, + in: nil, + inType: cel.NullType, + out: types.NullValue, }, { - in: &TestAllTypes{BoolVal: true}, - out: &proto3pb.TestAllTypes{}, - err: "type conversion error", + in: &TestAllTypes{BoolVal: true}, + inType: cel.ObjectType("ext.TestAllTypes"), + out: &proto3pb.TestAllTypes{}, + err: "type conversion error", + }, + { + in: [3]int32{1, 2, 3}, + inType: cel.ListType(cel.IntType), + out: []int32{1, 2, 3}, + }, + { + in: &[3]byte{1, 2, 3}, + inType: cel.BytesType, + out: []byte{1, 2, 3}, + }, + { + in: [3]byte{1, 2, 3}, + inType: cel.BytesType, + out: []byte{1, 2, 3}, }, } for _, c := range conversions { @@ -608,6 +631,9 @@ func TestNativeTypesConvertToNative(t *testing.T) { if types.IsError(inVal) { t.Fatalf("adapter.NativeToValue(%v) failed: %v", c.in, inVal) } + if inVal.Type().TypeName() != c.inType.TypeName() { + t.Fatalf("adapter.NativeToValue() got type %v, wanted type %v", inVal.Type(), c.inType) + } out, err := inVal.ConvertToNative(reflect.TypeOf(c.out)) if err != nil { if c.err != "" { @@ -848,6 +874,7 @@ type TestAllTypes struct { Uint64Val uint64 ListVal []*TestNestedType ArrayVal [1]*TestNestedType + BytesArrayVal [4]byte MapVal map[string]TestAllTypes PbVal *proto3pb.TestAllTypes CustomSliceVal []TestNestedSliceType