Skip to content

Commit 7cf7aca

Browse files
authored
Custom interface serializers (#104)
This PR changes the way custom serializers work. If you register serialization routines for type `T` where `T` is an interface type, the serialization routines will now be used for values of type `T` _and every type that implements `T`_. The serialization layer will no longer recursively scan values of type `T` when a custom serializer has been registered for `T` or for an interface that `T` implements. The assumption is that the custom serializers will not serialize the underlying memory regions in a way that makes preserving pointers outside `T` that alias those regions possible, hence scanning is not necessary. These two changes combined allow users to avoid scanning and serializing complex types that the serialization layer cannot currently handle. One example from https://github.com/protocolbuffers/protobuf-go is the `proto.Message` interface and the generated structs that implement it, which use all sorts of unsafe hacks (e.g. to prevent copying and comparisons, and to embed information used for Go and protobuf reflection). The protobuf library provides a native way to serialize any sort of `proto.Message` (via `anypb`). Users can now register custom serialization routines for `proto.Message` and don't have to manually register routines for every type that implements `proto.Message`.
2 parents 5b44969 + c0a8c25 commit 7cf7aca

File tree

3 files changed

+52
-13
lines changed

3 files changed

+52
-13
lines changed

types/reflect.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func deserializeType(d *Deserializer) reflect.Type {
2020

2121
func serializeAny(s *Serializer, t reflect.Type, p unsafe.Pointer) {
2222
if serde, ok := types.serdeOf(t); ok {
23-
serde.ser(s, p)
23+
serde.ser(s, t, p)
2424
return
2525
}
2626

@@ -93,7 +93,7 @@ func serializeAny(s *Serializer, t reflect.Type, p unsafe.Pointer) {
9393

9494
func deserializeAny(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
9595
if serde, ok := types.serdeOf(t); ok {
96-
serde.des(d, p)
96+
serde.des(d, t, p)
9797
return
9898
}
9999

@@ -653,7 +653,7 @@ func deserializeFunc(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
653653
if fn.Type == nil {
654654
panic(name + ": function type is missing")
655655
}
656-
if fn.Type != t {
656+
if !t.AssignableTo(fn.Type) {
657657
panic(name + ": function type mismatch: " + fn.Type.String() + " != " + t.String())
658658
}
659659

types/scan.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ func scan(s *Serializer, t reflect.Type, p unsafe.Pointer) {
243243
return
244244
}
245245

246+
// Don't scan types where custom serialization routines
247+
// have been registered.
248+
if _, ok := types.serdeOf(t); ok {
249+
return
250+
}
251+
246252
r := reflect.NewAt(t, p)
247253
if _, ok := s.scanptrs[r]; ok {
248254
return

types/typemap.go

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,34 +52,53 @@ func registerSerde[T any](tm *typemap,
5252

5353
t := reflect.TypeOf((*T)(nil)).Elem()
5454

55-
s := func(s *Serializer, p unsafe.Pointer) {
55+
s := func(s *Serializer, actualType reflect.Type, p unsafe.Pointer) {
56+
if t != actualType {
57+
v := reflect.NewAt(actualType, p).Elem()
58+
box := reflect.New(t)
59+
box.Elem().Set(v.Convert(t))
60+
p = box.UnsafePointer()
61+
}
5662
if err := serializer(s, (*T)(p)); err != nil {
5763
panic(fmt.Errorf("serializing %s: %w", t, err))
5864
}
5965
}
6066

61-
d := func(d *Deserializer, p unsafe.Pointer) {
62-
if err := deserializer(d, (*T)(p)); err != nil {
63-
panic(fmt.Errorf("deserializing %s: %w", t, err))
67+
d := func(d *Deserializer, actualType reflect.Type, p unsafe.Pointer) {
68+
if t != actualType {
69+
box := reflect.New(t)
70+
boxp := box.UnsafePointer()
71+
if err := deserializer(d, (*T)(boxp)); err != nil {
72+
panic(fmt.Errorf("deserializing %s: %w", t, err))
73+
}
74+
v := reflect.NewAt(actualType, p)
75+
reinterpreted := reflect.ValueOf(box.Elem().Interface())
76+
v.Elem().Set(reinterpreted)
77+
} else {
78+
if err := deserializer(d, (*T)(p)); err != nil {
79+
panic(fmt.Errorf("deserializing %s: %w", t, err))
80+
}
6481
}
6582
}
6683

6784
tm.attach(t, s, d)
6885
}
6986

70-
type serializerFunc func(*Serializer, unsafe.Pointer)
71-
type deserializerFunc func(d *Deserializer, p unsafe.Pointer)
87+
type serializerFunc func(*Serializer, reflect.Type, unsafe.Pointer)
88+
type deserializerFunc func(*Deserializer, reflect.Type, unsafe.Pointer)
7289

7390
type serde struct {
7491
id int
92+
t reflect.Type
7593
ser serializerFunc
7694
des deserializerFunc
7795
}
7896

7997
type typemap struct {
80-
custom []reflect.Type
81-
cache doublemap[reflect.Type, *typeinfo]
82-
serdes map[reflect.Type]serde
98+
custom []reflect.Type
99+
cache doublemap[reflect.Type, *typeinfo]
100+
serdes map[reflect.Type]serde
101+
interfaces []serde
83102
}
84103

85104
func newTypemap() *typemap {
@@ -99,15 +118,29 @@ func (m *typemap) attach(t reflect.Type, ser serializerFunc, des deserializerFun
99118
s.id = len(m.custom)
100119
m.custom = append(m.custom, t)
101120
}
121+
s.t = t
102122
s.ser = ser
103123
s.des = des
104124

105125
m.serdes[t] = s
126+
127+
if t.Kind() == reflect.Interface {
128+
m.interfaces = append(m.interfaces, s)
129+
}
106130
}
107131

108132
func (m *typemap) serdeOf(x reflect.Type) (serde, bool) {
109133
s, ok := m.serdes[x]
110-
return s, ok
134+
if ok {
135+
return s, true
136+
}
137+
for i := range m.interfaces {
138+
s := m.interfaces[i]
139+
if x.Implements(s.t) {
140+
return s, true
141+
}
142+
}
143+
return serde{}, false
111144
}
112145

113146
type doublemap[K, V comparable] struct {

0 commit comments

Comments
 (0)