Skip to content

Commit

Permalink
simplify wrapper
Browse files Browse the repository at this point in the history
Signed-off-by: 彭锟 <[email protected]>
  • Loading branch information
kom0055 committed Apr 21, 2024
1 parent 77fc8f8 commit f089fb7
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 172 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*.dll
*.so
*.dylib
.DS_Store

# Test binary, built with `go test -c`
*.test
Expand Down
11 changes: 1 addition & 10 deletions core/api.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
package core

import (
"reflect"
)

func RegisterType(name string, id any) error {
val := reflect.ValueOf(id)
typ := val.Type()
if err := globalProvider.setBidNameType(name, typ); err != nil {
return err
}
return nil
return globalProvider.RegisterType(name, id)
}
13 changes: 2 additions & 11 deletions core/api_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,12 @@ package core

import (
"encoding/json"
"reflect"
)

func MarshalJson(in any) ([]byte, error) {
outIf, err := shuttleMarshal(in)
if err != nil {
return nil, err
}
return json.Marshal(outIf)
return globalProvider.MarshalJson(in)
}

func UnmarshalJson(out any, raw json.RawMessage) error {

return shuttleUnmarshal(out, func(a any, inType, outTyp reflect.Type) error {
return json.Unmarshal(raw, a)
})

return globalProvider.UnmarshalJson(out, raw)
}
17 changes: 2 additions & 15 deletions core/api_msgpack.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
package core

import (
"reflect"

"github.com/vmihailenco/msgpack/v5"
)

func MarshalMsgPack(in any) ([]byte, error) {
outIf, err := shuttleMarshal(in)
if err != nil {
return nil, err
}
return msgpack.Marshal(outIf)
return globalProvider.MarshalMsgPack(in)
}

func UnmarshalMsgPack(out any, b []byte) error {
return shuttleUnmarshal(out, func(a any, inType, outTyp reflect.Type) error {
return msgpack.Unmarshal(b, a)
})

return globalProvider.UnmarshalMsgPack(out, b)
}
10 changes: 2 additions & 8 deletions core/api_yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,11 @@ import (
)

func MarshalYaml(in any) (any, error) {
return shuttleMarshal(in)
return globalProvider.MarshalYaml(in)
}

func UnmarshalYaml(out any, unmarshal func(any) error) error {

return shuttleUnmarshal(out, func(a any, inType, outTyp reflect.Type) error {
if err := unmarshal(a); err != nil {
return ReplaceYAMLTypeError(err, inType, outTyp)
}
return nil
})
return globalProvider.UnmarshalYaml(out, unmarshal)

}

Expand Down
189 changes: 187 additions & 2 deletions core/provider.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package core

import (
"encoding/json"
"fmt"
"reflect"

cmap "github.com/orcaman/concurrent-map/v2"
"github.com/vmihailenco/msgpack/v5"
)

const (
Expand All @@ -13,12 +15,16 @@ const (
)

var (
globalProvider = &provider{
globalProvider = newProvider()
)

func newProvider() *provider {
return &provider{
type2NameMap: cmap.NewStringer[reflect.Type, string](),
name2TypeMap: cmap.New[reflect.Type](),
marshalTypeMap: cmap.NewStringer[reflect.Type, reflect.Type](),
}
)
}

type provider struct {
type2NameMap cmap.ConcurrentMap[reflect.Type, string]
Expand Down Expand Up @@ -67,3 +73,182 @@ func (p *provider) setBidNameType(name string, t reflect.Type) error {
}
return nil
}

func (p *provider) RegisterType(name string, id any) error {
val := reflect.ValueOf(id)
typ := val.Type()
if err := p.setBidNameType(name, typ); err != nil {
return err
}
return nil
}

func (p *provider) MarshalJson(in any) ([]byte, error) {
outIf, err := p.shuttleMarshal(in)
if err != nil {
return nil, err
}
return json.Marshal(outIf)
}

func (p *provider) UnmarshalJson(out any, raw json.RawMessage) error {
return p.shuttleUnmarshal(out, func(a any, inType, outTyp reflect.Type) error {
return json.Unmarshal(raw, a)
})

}

func (p *provider) MarshalMsgPack(in any) ([]byte, error) {
outIf, err := p.shuttleMarshal(in)
if err != nil {
return nil, err
}
return msgpack.Marshal(outIf)
}

func (p *provider) UnmarshalMsgPack(out any, b []byte) error {
return p.shuttleUnmarshal(out, func(a any, inType, outTyp reflect.Type) error {
return msgpack.Unmarshal(b, a)
})

}

func (p *provider) MarshalYaml(in any) (any, error) {
return p.shuttleMarshal(in)
}

func (p *provider) UnmarshalYaml(out any, unmarshal func(any) error) error {

return p.shuttleUnmarshal(out, func(a any, inType, outTyp reflect.Type) error {
if err := unmarshal(a); err != nil {
return ReplaceYAMLTypeError(err, inType, outTyp)
}
return nil
})

}

func (p *provider) shuttleMarshal(in any) (any, error) {
inVal := reflect.ValueOf(in)
inVal = RevealValue(inVal)
inType := inVal.Type()
outType := p.getMarshalType(inType)
outPtr := reflect.New(outType)
outVal := outPtr.Elem()
for i, n := 0, inType.NumField(); i < n; i++ {
inField := inType.Field(i)
if !inField.IsExported() {
continue
}
oriVal := inVal.Field(i)
if !needWrap(inField) {
outVal.Field(i).Set(oriVal)
continue
}

fieldType := oriVal.Type()
if fieldType.Kind() == reflect.Slice {
newSliceVal := reflect.MakeSlice(WrapperSliceType, oriVal.Len(), oriVal.Len())
for i, n := 0, oriVal.Len(); i < n; i++ {
wrappedVal, err := p.wrapValue(oriVal.Index(i))
if err != nil {
return nil, err
}

newSliceVal.Index(i).Set(wrappedVal)
}
outVal.Field(i).Set(newSliceVal)
continue
}

wrappedVal, err := p.wrapValue(oriVal)
if err != nil {
return nil, err
}
outVal.Field(i).Set(wrappedVal)
continue
}
outIf := outVal.Interface()
return outIf, nil
}

func (p *provider) shuttleUnmarshal(out any, unmarshal func(any, reflect.Type, reflect.Type) error) error {
outVal := reflect.ValueOf(out)
if outVal.Kind() != reflect.Ptr {
return fmt.Errorf("discovery: can only unmarshal into a struct pointer: %T", out)
}
outVal = RevealValue(outVal)
if outVal.Kind() != reflect.Struct {
return fmt.Errorf("discovery: can only unmarshal into a struct pointer: %T", out)
}
outTyp := outVal.Type()
inType := p.getMarshalType(outTyp)
inPtr := reflect.New(inType)
inVal := inPtr.Elem()

if err := unmarshal(inPtr.Interface(), inType, outTyp); err != nil {
return err
}

for i, n := 0, inType.NumField(); i < n; i++ {
inField := inType.Field(i)
if !inField.IsExported() {
continue
}

inValIdxi := inVal.Field(i)
fieldTypeIdxi := inValIdxi.Type()
fieldTypeIdxi = RevealType(fieldTypeIdxi)
if fieldTypeIdxi == WrapperSliceType {
outValIdxi := outVal.Field(i)
outValTypeIdxi := outValIdxi.Type()

cvTyp := fieldTypeIdxi.Elem()
if cvTyp != WrapperType {
outVal.Field(i).Set(inValIdxi)
continue
}
if inValIdxi.Len() == 0 {
continue
}
newSliceVal := reflect.MakeSlice(outValTypeIdxi, inValIdxi.Len(), inValIdxi.Len())
for i, n := 0, inValIdxi.Len(); i < n; i++ {

val := inValIdxi.Index(i)
field1 := val.Field(1)
field1 = RevealInterface(field1)
newSliceVal.Index(i).Set(field1)
}
outVal.Field(i).Set(newSliceVal)
continue
}

if fieldTypeIdxi == WrapperType {
val := inValIdxi

field1 := val.Field(1)
field1 = RevealInterface(field1)
outVal.Field(i).Set(field1)

continue
}
outVal.Field(i).Set(inValIdxi)
}
return nil
}

func (p *provider) wrapValue(oriVal reflect.Value) (reflect.Value, error) {

val := oriVal
typ := RevealInterface(val).Type()
name, err := p.getNameByType(typ)
if err != nil {
return emptyValue, err
}

wrapper := Wrapper{
Kind: name,
Value: oriVal.Interface(),
}
return reflect.ValueOf(wrapper), nil
}
Loading

0 comments on commit f089fb7

Please sign in to comment.