From 4214131ed726fa5c2ec3f1d943ac4e696ed523fe Mon Sep 17 00:00:00 2001 From: Dean Karn Date: Sun, 28 May 2023 09:06:26 -0700 Subject: [PATCH] Handle ptr smart ptrs (#44) - Added new default & set types for `chan`, `map`, `slice`, `time.Time` and `Pointer` types. - Updated to handle pointer to a `Slice` or `Map` for situations where code is not under your control. Fixes #43 --- README.md | 9 +- modifiers/multi.go | 53 +++++++++++ modifiers/multi_test.go | 196 ++++++++++++++++++++++++++++++++++++++++ mold.go | 10 ++ 4 files changed, 267 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b4d68dc..4ebde53 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ Package mold ============ -![Project status](https://img.shields.io/badge/version-4.4.0-green.svg) +![Project status](https://img.shields.io/badge/version-4.5.0-green.svg) [![Build Status](https://travis-ci.org/go-playground/mold.svg?branch=v2)](https://travis-ci.org/go-playground/mold) [![Coverage Status](https://coveralls.io/repos/github/go-playground/mold/badge.svg?branch=v2)](https://coveralls.io/github/go-playground/mold?branch=v2) [![Go Report Card](https://goreportcard.com/badge/github.com/go-playground/mold)](https://goreportcard.com/report/github.com/go-playground/mold) @@ -58,7 +58,14 @@ These functions modify the data in-place. | ucase | Uppercases the data. | | ucfirst | Upper cases the first character of the data. | +**Special Notes:** +`default` and `set` modifiers are special in that they can be used to set the value of a field or underlying type information or attributes and both use the same underlying function to set the data. +Setting a Param will have the following special effects on data types where it's not just the value being set: +- Chan - param used to set the buffer size, default = 0. +- Slice - param used to set the capacity, default = 0. +- Map - param used to set the size, default = 0. +- time.Time - param used to set the time format OR value, default = time.Now(), `utc` = time.Now().UTC(), other tries to parse using RFC3339Nano and set a time value. Scrubbers ---------- diff --git a/modifiers/multi.go b/modifiers/multi.go index 2a35ece..9320f00 100644 --- a/modifiers/multi.go +++ b/modifiers/multi.go @@ -4,6 +4,7 @@ import ( "context" "reflect" "strconv" + "strings" "time" "github.com/go-playground/mold/v4" @@ -11,6 +12,7 @@ import ( var ( durationType = reflect.TypeOf(time.Duration(0)) + timeType = reflect.TypeOf(time.Time{}) ) // defaultValue allows setting of a default value IF no value is already present. @@ -73,6 +75,57 @@ func setValue(ctx context.Context, fl mold.FieldLevel) error { } fl.Field().SetBool(value) + case reflect.Map: + var n int + var err error + if fl.Param() != "" { + n, err = strconv.Atoi(fl.Param()) + if err != nil { + return err + } + } + fl.Field().Set(reflect.MakeMapWithSize(fl.Field().Type(), n)) + + case reflect.Slice: + var cap int + var err error + if fl.Param() != "" { + cap, err = strconv.Atoi(fl.Param()) + if err != nil { + return err + } + } + fl.Field().Set(reflect.MakeSlice(fl.Field().Type(), 0, cap)) + + case reflect.Struct: + if fl.Field().Type() == timeType { + if fl.Param() != "" { + if strings.ToLower(fl.Param()) == "utc" { + fl.Field().Set(reflect.ValueOf(time.Now().UTC())) + } else { + t, err := time.Parse(time.RFC3339Nano, fl.Param()) + if err != nil { + return err + } + fl.Field().Set(reflect.ValueOf(t)) + } + } else { + fl.Field().Set(reflect.ValueOf(time.Now())) + } + } + case reflect.Chan: + var buffer int + var err error + if fl.Param() != "" { + buffer, err = strconv.Atoi(fl.Param()) + if err != nil { + return err + } + } + fl.Field().Set(reflect.MakeChan(fl.Field().Type(), buffer)) + + case reflect.Ptr: + fl.Field().Set(reflect.New(fl.Field().Type().Elem())) } return nil } diff --git a/modifiers/multi_test.go b/modifiers/multi_test.go index 572854b..201a46c 100644 --- a/modifiers/multi_test.go +++ b/modifiers/multi_test.go @@ -8,6 +8,202 @@ import ( . "github.com/go-playground/assert/v2" ) +func TestDefaultSetSpecialTypes(t *testing.T) { + conform := New() + + tests := []struct { + name string + field interface{} + tags string + vf func(field interface{}) + expectError bool + }{ + { + name: "default map", + field: (map[string]struct{})(nil), + tags: "default", + vf: func(field interface{}) { + m := field.(map[string]struct{}) + Equal(t, len(m), 0) + }, + }, + { + name: "default map with size", + field: (map[string]struct{})(nil), + tags: "default=5", + vf: func(field interface{}) { + m := field.(map[string]struct{}) + Equal(t, len(m), 0) + }, + }, + { + name: "set map with size", + field: (map[string]struct{})(nil), + tags: "set=5", + vf: func(field interface{}) { + m := field.(map[string]struct{}) + Equal(t, len(m), 0) + }, + }, + { + name: "default slice", + field: ([]string)(nil), + tags: "default", + vf: func(field interface{}) { + m := field.([]string) + Equal(t, len(m), 0) + Equal(t, cap(m), 0) + }, + }, + { + name: "default slice with capacity", + field: ([]string)(nil), + tags: "default=5", + vf: func(field interface{}) { + m := field.([]string) + Equal(t, len(m), 0) + Equal(t, cap(m), 5) + }, + }, + { + name: "set slice", + field: ([]string)(nil), + tags: "set", + vf: func(field interface{}) { + m := field.([]string) + Equal(t, len(m), 0) + Equal(t, cap(m), 0) + }, + }, + { + name: "set slice with capacity", + field: ([]string)(nil), + tags: "set=5", + vf: func(field interface{}) { + m := field.([]string) + Equal(t, len(m), 0) + Equal(t, cap(m), 5) + }, + }, + { + name: "default chan", + field: (chan struct{})(nil), + tags: "default", + vf: func(field interface{}) { + m := field.(chan struct{}) + Equal(t, len(m), 0) + Equal(t, cap(m), 0) + }, + }, + { + name: "default chan with buffer", + field: (chan struct{})(nil), + tags: "default=5", + vf: func(field interface{}) { + m := field.(chan struct{}) + Equal(t, len(m), 0) + Equal(t, cap(m), 5) + }, + }, + { + name: "default time.Time", + field: time.Time{}, + tags: "default", + vf: func(field interface{}) { + m := field.(time.Time) + Equal(t, m.Location(), time.Local) + }, + }, + { + name: "default time.Time utc", + field: time.Time{}, + tags: "default=utc", + vf: func(field interface{}) { + m := field.(time.Time) + Equal(t, m.Location(), time.UTC) + }, + }, + { + name: "default time.Time to value", + field: time.Time{}, + tags: "default=2023-05-28T15:50:31Z", + vf: func(field interface{}) { + m := field.(time.Time) + Equal(t, m.Location(), time.UTC) + + tm, err := time.Parse(time.RFC3339Nano, "2023-05-28T15:50:31Z") + Equal(t, err, nil) + Equal(t, tm.Equal(m), true) + + }, + }, + { + name: "set time.Time", + field: time.Time{}, + tags: "set", + vf: func(field interface{}) { + m := field.(time.Time) + Equal(t, m.Location(), time.Local) + }, + }, + { + name: "set time.Time utc", + field: time.Time{}, + tags: "set=utc", + vf: func(field interface{}) { + m := field.(time.Time) + Equal(t, m.Location(), time.UTC) + }, + }, + { + name: "set time.Time to value", + field: time.Time{}, + tags: "set=2023-05-28T15:50:31Z", + vf: func(field interface{}) { + m := field.(time.Time) + Equal(t, m.Location(), time.UTC) + + tm, err := time.Parse(time.RFC3339Nano, "2023-05-28T15:50:31Z") + Equal(t, err, nil) + Equal(t, tm.Equal(m), true) + + }, + }, + { + name: "default pointer to slice", + field: (*[]string)(nil), + tags: "default", + vf: func(field interface{}) { + m := field.([]string) + Equal(t, len(m), 0) + }, + }, + { + name: "set pointer to slice", + field: (*[]string)(nil), + tags: "set", + vf: func(field interface{}) { + m := field.([]string) + Equal(t, len(m), 0) + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := conform.Field(context.Background(), &tc.field, tc.tags) + if tc.expectError { + NotEqual(t, err, nil) + return + } + Equal(t, err, nil) + tc.vf(tc.field) + }) + } +} + func TestSet(t *testing.T) { type State int diff --git a/mold.go b/mold.go index bb389e6..dee4ee6 100644 --- a/mold.go +++ b/mold.go @@ -239,6 +239,14 @@ func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, ct *cT err = t.setByIterable(ctx, current, ct) case reflect.Map: err = t.setByMap(ctx, current, ct) + case reflect.Ptr: + innerKind := current.Type().Elem().Kind() + if innerKind == reflect.Slice || innerKind == reflect.Map { + // is a nil pointer to a slice or map, nothing to do. + return nil + } + // not a valid use of the dive tag + fallthrough default: err = ErrInvalidDive } @@ -267,6 +275,8 @@ func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, ct *cT }); err != nil { return } + // value could have been changed or reassigned + current, kind = t.extractType(current) } ct = ct.next }