Skip to content

Commit

Permalink
Support serializing primitive reflect values
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Nov 7, 2023
1 parent d8b85a2 commit cff35ab
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 0 deletions.
124 changes: 124 additions & 0 deletions types/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ func serializeAny(s *Serializer, t reflect.Type, p unsafe.Pointer) {
return
}

switch t {
case reflectValueType:
serializeReflectValue(s, *(*reflect.Value)(p))
return
}

switch t.Kind() {
case reflect.Invalid:
panic(fmt.Errorf("can't serialize reflect.Invalid"))
Expand Down Expand Up @@ -89,6 +95,12 @@ func deserializeAny(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
return
}

switch t {
case reflectValueType:
deserializeReflectValue(d, t, p)
return
}

switch t.Kind() {
case reflect.Invalid:
panic(fmt.Errorf("can't deserialize reflect.Invalid"))
Expand Down Expand Up @@ -147,6 +159,118 @@ func deserializeAny(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
}
}

var reflectValueType = reflect.TypeOf(reflect.Value{})

func serializeReflectValue(s *Serializer, v reflect.Value) {
t := v.Type()
serializeType(s, t)

var p unsafe.Pointer
switch t.Kind() {
case reflect.Invalid:
panic(fmt.Errorf("can't serialize reflect.Invalid"))
case reflect.Bool:
b := v.Bool()
p = unsafe.Pointer(&b)
case reflect.Int:
i := int(v.Int())
p = unsafe.Pointer(&i)
case reflect.Int8:
i := int8(v.Int())
p = unsafe.Pointer(&i)
case reflect.Int16:
i := int16(v.Int())
p = unsafe.Pointer(&i)
case reflect.Int32:
i := int32(v.Int())
p = unsafe.Pointer(&i)
case reflect.Int64:
i := int64(v.Int())
p = unsafe.Pointer(&i)
case reflect.Uint:
u := uint(v.Uint())
p = unsafe.Pointer(&u)
case reflect.Uint8:
u := uint8(v.Uint())
p = unsafe.Pointer(&u)
case reflect.Uint16:
u := uint16(v.Uint())
p = unsafe.Pointer(&u)
case reflect.Uint32:
u := uint32(v.Uint())
p = unsafe.Pointer(&u)
case reflect.Uint64:
u := uint64(v.Uint())
p = unsafe.Pointer(&u)
case reflect.Float32:
f := float32(v.Float())
p = unsafe.Pointer(&f)
case reflect.Float64:
f := float64(v.Float())
p = unsafe.Pointer(&f)
case reflect.Complex64:
c := complex64(v.Complex())
p = unsafe.Pointer(&c)
case reflect.Complex128:
c := complex128(v.Complex())
p = unsafe.Pointer(&c)
case reflect.String:
s := v.String()
p = unsafe.Pointer(&s)
default:
panic(fmt.Sprintf("not implemented: serializing reflect.Value with type %s", t))
}
serializeAny(s, t, p)
}

func deserializeReflectValue(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
rt := deserializeType(d)
deserializeAny(d, rt, p)

var v reflect.Value
switch rt.Kind() {
case reflect.Invalid:
panic(fmt.Errorf("can't deserialize reflect.Invalid"))
case reflect.Bool:
v = reflect.ValueOf(*(*bool)(p))
case reflect.Int:
v = reflect.ValueOf(*(*int)(p))
case reflect.Int8:
v = reflect.ValueOf(*(*int8)(p))
case reflect.Int16:
v = reflect.ValueOf(*(*int16)(p))
case reflect.Int32:
v = reflect.ValueOf(*(*int32)(p))
case reflect.Int64:
v = reflect.ValueOf(*(*int64)(p))
case reflect.Uint:
v = reflect.ValueOf(*(*uint)(p))
case reflect.Uint8:
v = reflect.ValueOf(*(*uint8)(p))
case reflect.Uint16:
v = reflect.ValueOf(*(*uint16)(p))
case reflect.Uint32:
v = reflect.ValueOf(*(*uint32)(p))
case reflect.Uint64:
v = reflect.ValueOf(*(*uint64)(p))
case reflect.Float32:
v = reflect.ValueOf(*(*float32)(p))
case reflect.Float64:
v = reflect.ValueOf(*(*float64)(p))
case reflect.Complex64:
v = reflect.ValueOf(*(*complex64)(p))
case reflect.Complex128:
v = reflect.ValueOf(*(*complex128)(p))
case reflect.String:
v = reflect.ValueOf(*(*string)(p))
default:
panic(fmt.Sprintf("not implemented: deserializing reflect.Value with type %s", t))
}

r := reflect.NewAt(t, p)
r.Elem().Set(reflect.ValueOf(v))
}

func serializePointedAt(s *Serializer, t reflect.Type, p unsafe.Pointer) {
// If this is a nil pointer, write it as such.
if p == nil {
Expand Down
41 changes: 41 additions & 0 deletions types/serde_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"math"
"net/http"
"reflect"
"strconv"
Expand Down Expand Up @@ -111,6 +112,20 @@ func TestReflect(t *testing.T) {
struct{}{},
errors.New("test"),
unsafe.Pointer(nil),
reflect.ValueOf("foo"),
reflect.ValueOf(true),
reflect.ValueOf(int(1)),
reflect.ValueOf(int8(math.MaxInt8)),
reflect.ValueOf(int16(-math.MaxInt16)),
reflect.ValueOf(int32(math.MaxInt32)),
reflect.ValueOf(int64(-math.MaxInt64)),
reflect.ValueOf(uint(1)),
reflect.ValueOf(uint8(math.MaxUint8)),
reflect.ValueOf(uint16(math.MaxUint16)),
reflect.ValueOf(uint32(math.MaxUint8)),
reflect.ValueOf(uint64(math.MaxUint64)),
reflect.ValueOf(float32(3.14)),
reflect.ValueOf(float64(math.MaxFloat64)),
}

for _, x := range cases {
Expand Down Expand Up @@ -638,13 +653,39 @@ func deepEqual(v1, v2 any) bool {
return false
}

if t1 == reflect.TypeOf(reflect.Value{}) {
return equalReflectValue(v1.(reflect.Value), v2.(reflect.Value))
}

if t1.Kind() == reflect.Func {
return FuncAddr(v1) == FuncAddr(v2)
}

return reflect.DeepEqual(v1, v2)
}

func equalReflectValue(v1, v2 reflect.Value) bool {
if v1.Type() != v2.Type() {
return false
}
switch v1.Kind() {
case reflect.Bool:
return v1.Bool() == v2.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return v1.Uint() == v2.Uint()
case reflect.Float32, reflect.Float64:
return v1.Float() == v2.Float()
case reflect.Complex64, reflect.Complex128:
return v1.Complex() == v2.Complex()
case reflect.String:
return v1.String() == v2.String()
default:
panic(fmt.Sprintf("not implemented: comparison of reflect.Value with type %T", v1))
}
}

func assertRoundTrip[T any](t *testing.T, orig T) T {
t.Helper()

Expand Down

0 comments on commit cff35ab

Please sign in to comment.