Skip to content

Commit

Permalink
Custom interface serializers (#104)
Browse files Browse the repository at this point in the history
This PR changes the way custom serializers work. If you register
serialization routines for type `T` where `T` is an interface type, the
serialization routines will now be used for values of type `T` _and
every type that implements `T`_.

The serialization layer will no longer recursively scan values of type
`T` when a custom serializer has been registered for `T` or for an
interface that `T` implements. The assumption is that the custom
serializers will not serialize the underlying memory regions in a way
that makes preserving pointers outside `T` that alias those regions
possible, hence scanning is not necessary.

These two changes combined allow users to avoid scanning and serializing
complex types that the serialization layer cannot currently handle. One
example from https://github.com/protocolbuffers/protobuf-go is the
`proto.Message` interface and the generated structs that implement it,
which use all sorts of unsafe hacks (e.g. to prevent copying and
comparisons, and to embed information used for Go and protobuf
reflection). The protobuf library provides a native way to serialize any
sort of `proto.Message` (via `anypb`). Users can now register custom
serialization routines for `proto.Message` and don't have to manually
register routines for every type that implements `proto.Message`.
  • Loading branch information
chriso authored Nov 9, 2023
2 parents 5b44969 + c0a8c25 commit 7cf7aca
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 13 deletions.
6 changes: 3 additions & 3 deletions types/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func deserializeType(d *Deserializer) reflect.Type {

func serializeAny(s *Serializer, t reflect.Type, p unsafe.Pointer) {
if serde, ok := types.serdeOf(t); ok {
serde.ser(s, p)
serde.ser(s, t, p)
return
}

Expand Down Expand Up @@ -93,7 +93,7 @@ func serializeAny(s *Serializer, t reflect.Type, p unsafe.Pointer) {

func deserializeAny(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
if serde, ok := types.serdeOf(t); ok {
serde.des(d, p)
serde.des(d, t, p)
return
}

Expand Down Expand Up @@ -653,7 +653,7 @@ func deserializeFunc(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
if fn.Type == nil {
panic(name + ": function type is missing")
}
if fn.Type != t {
if !t.AssignableTo(fn.Type) {
panic(name + ": function type mismatch: " + fn.Type.String() + " != " + t.String())
}

Expand Down
6 changes: 6 additions & 0 deletions types/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ func scan(s *Serializer, t reflect.Type, p unsafe.Pointer) {
return
}

// Don't scan types where custom serialization routines
// have been registered.
if _, ok := types.serdeOf(t); ok {
return
}

r := reflect.NewAt(t, p)
if _, ok := s.scanptrs[r]; ok {
return
Expand Down
53 changes: 43 additions & 10 deletions types/typemap.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,53 @@ func registerSerde[T any](tm *typemap,

t := reflect.TypeOf((*T)(nil)).Elem()

s := func(s *Serializer, p unsafe.Pointer) {
s := func(s *Serializer, actualType reflect.Type, p unsafe.Pointer) {
if t != actualType {
v := reflect.NewAt(actualType, p).Elem()
box := reflect.New(t)
box.Elem().Set(v.Convert(t))
p = box.UnsafePointer()
}
if err := serializer(s, (*T)(p)); err != nil {
panic(fmt.Errorf("serializing %s: %w", t, err))
}
}

d := func(d *Deserializer, p unsafe.Pointer) {
if err := deserializer(d, (*T)(p)); err != nil {
panic(fmt.Errorf("deserializing %s: %w", t, err))
d := func(d *Deserializer, actualType reflect.Type, p unsafe.Pointer) {
if t != actualType {
box := reflect.New(t)
boxp := box.UnsafePointer()
if err := deserializer(d, (*T)(boxp)); err != nil {
panic(fmt.Errorf("deserializing %s: %w", t, err))
}
v := reflect.NewAt(actualType, p)
reinterpreted := reflect.ValueOf(box.Elem().Interface())
v.Elem().Set(reinterpreted)
} else {
if err := deserializer(d, (*T)(p)); err != nil {
panic(fmt.Errorf("deserializing %s: %w", t, err))
}
}
}

tm.attach(t, s, d)
}

type serializerFunc func(*Serializer, unsafe.Pointer)
type deserializerFunc func(d *Deserializer, p unsafe.Pointer)
type serializerFunc func(*Serializer, reflect.Type, unsafe.Pointer)
type deserializerFunc func(*Deserializer, reflect.Type, unsafe.Pointer)

type serde struct {
id int
t reflect.Type
ser serializerFunc
des deserializerFunc
}

type typemap struct {
custom []reflect.Type
cache doublemap[reflect.Type, *typeinfo]
serdes map[reflect.Type]serde
custom []reflect.Type
cache doublemap[reflect.Type, *typeinfo]
serdes map[reflect.Type]serde
interfaces []serde
}

func newTypemap() *typemap {
Expand All @@ -99,15 +118,29 @@ func (m *typemap) attach(t reflect.Type, ser serializerFunc, des deserializerFun
s.id = len(m.custom)
m.custom = append(m.custom, t)
}
s.t = t
s.ser = ser
s.des = des

m.serdes[t] = s

if t.Kind() == reflect.Interface {
m.interfaces = append(m.interfaces, s)
}
}

func (m *typemap) serdeOf(x reflect.Type) (serde, bool) {
s, ok := m.serdes[x]
return s, ok
if ok {
return s, true
}
for i := range m.interfaces {
s := m.interfaces[i]
if x.Implements(s.t) {
return s, true
}
}
return serde{}, false
}

type doublemap[K, V comparable] struct {
Expand Down

0 comments on commit 7cf7aca

Please sign in to comment.