diff --git a/decode.go b/decode.go index 6b15618d..d3dbabcb 100644 --- a/decode.go +++ b/decode.go @@ -29,6 +29,7 @@ type Decoder struct { referenceReaders []io.Reader anchorNodeMap map[string]ast.Node anchorValueMap map[string]reflect.Value + customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error toCommentMap CommentMap opts []DecodeOption referenceFiles []string @@ -50,6 +51,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder { reader: r, anchorNodeMap: map[string]ast.Node{}, anchorValueMap: map[string]reflect.Value{}, + customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{}, opts: opts, referenceReaders: []io.Reader{}, referenceFiles: []string{}, @@ -638,8 +640,38 @@ type jsonUnmarshaler interface { UnmarshalJSON([]byte) error } +func (d *Decoder) existsTypeInCustomUnmarshalerMap(t reflect.Type) bool { + if _, exists := d.customUnmarshalerMap[t]; exists { + return true + } + + globalCustomUnmarshalerMu.Lock() + defer globalCustomUnmarshalerMu.Unlock() + if _, exists := globalCustomUnmarshalerMap[t]; exists { + return true + } + return false +} + +func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(interface{}, []byte) error, bool) { + if unmarshaler, exists := d.customUnmarshalerMap[t]; exists { + return unmarshaler, exists + } + + globalCustomUnmarshalerMu.Lock() + defer globalCustomUnmarshalerMu.Unlock() + if unmarshaler, exists := globalCustomUnmarshalerMap[t]; exists { + return unmarshaler, exists + } + return nil, false +} + func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool { - iface := dst.Addr().Interface() + ptrValue := dst.Addr() + if d.existsTypeInCustomUnmarshalerMap(ptrValue.Type()) { + return true + } + iface := ptrValue.Interface() switch iface.(type) { case BytesUnmarshalerContext: return true @@ -662,7 +694,18 @@ func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool { } func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, src ast.Node) error { - iface := dst.Addr().Interface() + ptrValue := dst.Addr() + if unmarshaler, exists := d.unmarshalerFromCustomUnmarshalerMap(ptrValue.Type()); exists { + b, err := d.unmarshalableDocument(src) + if err != nil { + return errors.Wrapf(err, "failed to UnmarshalYAML") + } + if err := unmarshaler(ptrValue.Interface(), b); err != nil { + return errors.Wrapf(err, "failed to UnmarshalYAML") + } + return nil + } + iface := ptrValue.Interface() if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok { b, err := d.unmarshalableDocument(src) diff --git a/decode_test.go b/decode_test.go index 3938097d..2549bea0 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1858,6 +1858,54 @@ func TestDecoder_UseJSONUnmarshaler(t *testing.T) { } } +func TestDecoder_CustomUnmarshaler(t *testing.T) { + t.Run("override struct type", func(t *testing.T) { + type T struct { + Foo string `yaml:"foo"` + } + src := []byte(`foo: "bar"`) + var v T + if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomUnmarshaler[T](func(dst *T, b []byte) error { + if !bytes.Equal(src, b) { + t.Fatalf("failed to get decode target buffer. expected %q but got %q", src, b) + } + var v T + if err := yaml.Unmarshal(b, &v); err != nil { + return err + } + if v.Foo != "bar" { + t.Fatal("failed to decode") + } + dst.Foo = "bazbaz" // assign another value to target + return nil + })); err != nil { + t.Fatal(err) + } + if v.Foo != "bazbaz" { + t.Fatalf("failed to switch to custom unmarshaler. got: %v", v.Foo) + } + }) + t.Run("override bytes type", func(t *testing.T) { + type T struct { + Foo []byte `yaml:"foo"` + } + src := []byte(`foo: "bar"`) + var v T + if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomUnmarshaler[[]byte](func(dst *[]byte, b []byte) error { + if !bytes.Equal(b, []byte(`"bar"`)) { + t.Fatalf("failed to get target buffer: %q", b) + } + *dst = []byte("bazbaz") + return nil + })); err != nil { + t.Fatal(err) + } + if !bytes.Equal(v.Foo, []byte("bazbaz")) { + t.Fatalf("failed to switch to custom unmarshaler. got: %q", v.Foo) + } + }) +} + type unmarshalContext struct { v int } diff --git a/encode.go b/encode.go index 4543138b..7d8d81e0 100644 --- a/encode.go +++ b/encode.go @@ -37,6 +37,7 @@ type Encoder struct { useJSONMarshaler bool anchorCallback func(*ast.AnchorNode, interface{}) error anchorPtrToNameMap map[uintptr]string + customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error) useLiteralStyleIfMultiline bool commentMap map[*Path][]*Comment written bool @@ -56,6 +57,7 @@ func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder { opts: opts, indent: DefaultIndentSpaces, anchorPtrToNameMap: map[uintptr]string{}, + customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){}, line: 1, column: 1, offset: 0, @@ -273,10 +275,39 @@ type jsonMarshaler interface { MarshalJSON() ([]byte, error) } +func (e *Encoder) existsTypeInCustomMarshalerMap(t reflect.Type) bool { + if _, exists := e.customMarshalerMap[t]; exists { + return true + } + + globalCustomMarshalerMu.Lock() + defer globalCustomMarshalerMu.Unlock() + if _, exists := globalCustomMarshalerMap[t]; exists { + return true + } + return false +} + +func (e *Encoder) marshalerFromCustomMarshalerMap(t reflect.Type) (func(interface{}) ([]byte, error), bool) { + if marshaler, exists := e.customMarshalerMap[t]; exists { + return marshaler, exists + } + + globalCustomMarshalerMu.Lock() + defer globalCustomMarshalerMu.Unlock() + if marshaler, exists := globalCustomMarshalerMap[t]; exists { + return marshaler, exists + } + return nil, false +} + func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool { if !v.CanInterface() { return false } + if e.existsTypeInCustomMarshalerMap(v.Type()) { + return true + } iface := v.Interface() switch iface.(type) { case BytesMarshalerContext: @@ -302,6 +333,18 @@ func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool { func (e *Encoder) encodeByMarshaler(ctx context.Context, v reflect.Value, column int) (ast.Node, error) { iface := v.Interface() + if marshaler, exists := e.marshalerFromCustomMarshalerMap(v.Type()); exists { + doc, err := marshaler(iface) + if err != nil { + return nil, errors.Wrapf(err, "failed to MarshalYAML") + } + node, err := e.encodeDocument(doc) + if err != nil { + return nil, errors.Wrapf(err, "failed to encode document") + } + return node, nil + } + if marshaler, ok := iface.(BytesMarshalerContext); ok { doc, err := marshaler.MarshalYAML(ctx) if err != nil { diff --git a/encode_test.go b/encode_test.go index bdb667a0..74b1aa59 100644 --- a/encode_test.go +++ b/encode_test.go @@ -4,13 +4,14 @@ import ( "bytes" "context" "fmt" - "github.com/goccy/go-yaml/parser" "math" "reflect" "strconv" "testing" "time" + "github.com/goccy/go-yaml/parser" + "github.com/goccy/go-yaml" "github.com/goccy/go-yaml/ast" ) @@ -1177,6 +1178,40 @@ a: } } +func TestEncoder_CustomMarshaler(t *testing.T) { + t.Run("override struct type", func(t *testing.T) { + type T struct { + Foo string `yaml:"foo"` + } + b, err := yaml.MarshalWithOptions(&T{Foo: "bar"}, yaml.CustomMarshaler[T](func(v T) ([]byte, error) { + return []byte(`"override"`), nil + })) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(b, []byte("\"override\"\n")) { + t.Fatalf("failed to switch to custom marshaler. got: %q", b) + } + }) + t.Run("override bytes type", func(t *testing.T) { + type T struct { + Foo []byte `yaml:"foo"` + } + b, err := yaml.MarshalWithOptions(&T{Foo: []byte("bar")}, yaml.CustomMarshaler[[]byte](func(v []byte) ([]byte, error) { + if !bytes.Equal(v, []byte("bar")) { + t.Fatalf("failed to get src buffer: %q", v) + } + return []byte(`override`), nil + })) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(b, []byte("foo: override\n")) { + t.Fatalf("failed to switch to custom marshaler. got: %q", b) + } + }) +} + func TestEncoder_MultipleDocuments(t *testing.T) { var buf bytes.Buffer enc := yaml.NewEncoder(&buf) diff --git a/option.go b/option.go index 122a63dd..eab5d43a 100644 --- a/option.go +++ b/option.go @@ -2,6 +2,7 @@ package yaml import ( "io" + "reflect" "github.com/goccy/go-yaml/ast" ) @@ -94,6 +95,20 @@ func UseJSONUnmarshaler() DecodeOption { } } +// CustomUnmarshaler overrides any decoding process for the type specified in generics. +// +// NOTE: If RegisterCustomUnmarshaler and CustomUnmarshaler of DecodeOption are specified for the same type, +// the CustomUnmarshaler specified in DecodeOption takes precedence. +func CustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) DecodeOption { + return func(d *Decoder) error { + var typ *T + d.customUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, b []byte) error { + return unmarshaler(v.(*T), b) + } + return nil + } +} + // EncodeOption functional option type for Encoder type EncodeOption func(e *Encoder) error @@ -165,6 +180,21 @@ func UseJSONMarshaler() EncodeOption { } } +// CustomMarshaler overrides any encoding process for the type specified in generics. +// +// NOTE: If type T implements MarshalYAML for pointer receiver, the type specified in CustomMarshaler must be *T. +// If RegisterCustomMarshaler and CustomMarshaler of EncodeOption are specified for the same type, +// the CustomMarshaler specified in EncodeOption takes precedence. +func CustomMarshaler[T any](marshaler func(T) ([]byte, error)) EncodeOption { + return func(e *Encoder) error { + var typ T + e.customMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) ([]byte, error) { + return marshaler(v.(T)) + } + return nil + } +} + // CommentPosition type of the position for comment. type CommentPosition int diff --git a/yaml.go b/yaml.go index 6074f9d7..2e541d85 100644 --- a/yaml.go +++ b/yaml.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "io" + "reflect" + "sync" "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/internal/errors" @@ -248,3 +250,41 @@ func JSONToYAML(bytes []byte) ([]byte, error) { } return out, nil } + +var ( + globalCustomMarshalerMu sync.Mutex + globalCustomUnmarshalerMu sync.Mutex + globalCustomMarshalerMap = map[reflect.Type]func(interface{}) ([]byte, error){} + globalCustomUnmarshalerMap = map[reflect.Type]func(interface{}, []byte) error{} +) + +// RegisterCustomMarshaler overrides any encoding process for the type specified in generics. +// If you want to switch the behavior for each encoder, use `CustomMarshaler` defined as EncodeOption. +// +// NOTE: If type T implements MarshalYAML for pointer receiver, the type specified in RegisterCustomMarshaler must be *T. +// If RegisterCustomMarshaler and CustomMarshaler of EncodeOption are specified for the same type, +// the CustomMarshaler specified in EncodeOption takes precedence. +func RegisterCustomMarshaler[T any](marshaler func(T) ([]byte, error)) { + globalCustomMarshalerMu.Lock() + defer globalCustomMarshalerMu.Unlock() + + var typ T + globalCustomMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) ([]byte, error) { + return marshaler(v.(T)) + } +} + +// RegisterCustomUnmarshaler overrides any decoding process for the type specified in generics. +// If you want to switch the behavior for each decoder, use `CustomUnmarshaler` defined as DecodeOption. +// +// NOTE: If RegisterCustomUnmarshaler and CustomUnmarshaler of DecodeOption are specified for the same type, +// the CustomUnmarshaler specified in DecodeOption takes precedence. +func RegisterCustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) { + globalCustomUnmarshalerMu.Lock() + defer globalCustomUnmarshalerMu.Unlock() + + var typ *T + globalCustomUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, b []byte) error { + return unmarshaler(v.(*T), b) + } +} diff --git a/yaml_test.go b/yaml_test.go index dfe6df1a..5828629c 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -1160,3 +1160,36 @@ hoge: } }) } + +func TestRegisterCustomMarshaler(t *testing.T) { + type T struct { + Foo []byte `yaml:"foo"` + } + yaml.RegisterCustomMarshaler[T](func(_ T) ([]byte, error) { + return []byte(`"override"`), nil + }) + b, err := yaml.Marshal(&T{Foo: []byte("bar")}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(b, []byte("\"override\"\n")) { + t.Fatalf("failed to register custom marshaler. got: %q", b) + } +} + +func TestRegisterCustomUnmarshaler(t *testing.T) { + type T struct { + Foo []byte `yaml:"foo"` + } + yaml.RegisterCustomUnmarshaler[T](func(v *T, _ []byte) error { + v.Foo = []byte("override") + return nil + }) + var v T + if err := yaml.Unmarshal([]byte(`"foo: "bar"`), &v); err != nil { + t.Fatal(err) + } + if !bytes.Equal(v.Foo, []byte("override")) { + t.Fatalf("failed to decode. got %q", v.Foo) + } +}