diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 2b1734ec..2ffc5fdc 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,4 +1,6 @@ # 开发中 + +# v0.0.8 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) - [queue: API 定义](https://github.com/gotomicro/ekit/pull/109) - [queue: 基于堆和切片的优先级队列](https://github.com/gotomicro/ekit/pull/110) @@ -7,6 +9,33 @@ - [queue: 基于切片的并发阻塞队列和基于 CAS 的并发队列设计](https://github.com/gotomicro/ekit/pull/119) - [queue: 基于链表实现的有界/无界阻塞队列](https://github.com/gotomicro/ekit/pull/122) - [queue: ConcurrentLinkBlockingQueue重命名为ConcurrentLinkedBlockingQueue](https://github.com/gotomicro/ekit/pull/123) +- [syncx: sync.Cond的超时等待版,Cond.WaitWithContext(ctx)](https://github.com/ecodeclub/ekit/pull/192) +- [copier: ReflectCopier copier支持类型转换](https://github.com/ecodeclub/ekit/issues/197) +- [mapx: TreeMap 添加 Keys 和 Values 方法](https://github.com/ecodeclub/ekit/pull/181) +- [mapx: 修正 HashMap 中使用泛型不当的地方](https://github.com/ecodeclub/ekit/pull/186) +- [mapx: 支持 builtinMap,用于接入其它装饰器实现](https://github.com/ecodeclub/ekit/pull/202) +- [pool: 重构TaskPool测试用例](https://github.com/ecodeclub/ekit/pull/178) +- [sqlx:ScanRows 和 ScanAll方法](https://github.com/ecodeclub/ekit/pull/180) +- [mapx: 修复红黑树删除节点问题](https://github.com/ecodeclub/ekit/pull/183) +- [sqlx: 构建Scanner抽象替代现有ScanRows及ScanAll](https://github.com/ecodeclub/ekit/pull/182) +- [sqlx: 预定义 Rows 接口](https://github.com/ecodeclub/ekit/pull/209) +- [pool: 重构TaskPool](https://github.com/ecodeclub/ekit/pull/184) +- [syncx:Map 支持 LoadOrStoreFunc 方法](https://github.com/ecodeclub/ekit/pull/194) +- [mapx: MutipleTreeMap](https://github.com/ecodeclub/ekit/pull/187) +- [mapx: 为 MultipleMap 添加 PutVals 方法](https://github.com/ecodeclub/ekit/pull/189) +- [mapx: LinkedMap 特性](https://github.com/ecodeclub/ekit/pull/191) +- [copier: ReflectCopier 支持忽略字段](https://github.com/ecodeclub/ekit/pull/196) +- [syncx: 重构LoadOrStoreFunc方法及相关测试](https://github.com/ecodeclub/ekit/pull/198) +- [slice: 添加Add函数,在指定位置插入元素](https://github.com/ecodeclub/ekit/pull/201) +- [slice: 优化delete方法,无需从头开始遍历](https://github.com/ecodeclub/ekit/pull/203) +- [slice: 重构 slice 中使用 equalFunc 的方法](https://github.com/ecodeclub/ekit/pull/205) +- [randx: 新增生成随机code方法](https://github.com/ecodeclub/ekit/pull/207) +- [slice: intersect方法优化, symmetricDiffSet重构](https://github.com/ecodeclub/ekit/pull/208) +- [sqlx: 修复EncryptColumn Scan方法string分支错误](https://github.com/ecodeclub/ekit/pull/211) +- [sqlx: Scanner 添加 NextResultSet 方法](https://github.com/ecodeclub/ekit/pull/212) +- [ekit: AnyValue 支持As[Type]类型 String 转换](https://github.com/ecodeclub/ekit/pull/213) +- [stringx: unsafe 转换 string 和 []byte](https://github.com/ecodeclub/ekit/pull/215) + - [stringx: 添加 Benchmark](https://github.com/ecodeclub/ekit/pull/216) # v0.0.7 - [slice: FilterDelete](https://github.com/ecodeclub/ekit/pull/152) diff --git a/.gitignore b/.gitignore index 13da0af5..b12cde5d 100644 --- a/.gitignore +++ b/.gitignore @@ -15,5 +15,5 @@ # vendor/ .idea - +.vscode **/.DS_Store \ No newline at end of file diff --git a/.imgs/contact_me_qr.jpg b/.imgs/contact_me_qr.jpg new file mode 100644 index 00000000..2e59f446 Binary files /dev/null and b/.imgs/contact_me_qr.jpg differ diff --git a/README.md b/README.md index 88a7a3ac..aecce8a5 100644 --- a/README.md +++ b/README.md @@ -2,3 +2,13 @@ 泛型工具库。 - [文档](https://ekit.gocn.vip/ekit/develop/guide/) + +## 交流 + +交流群。原本我是觉得有一个群会削弱 github 的社区氛围,但是比较多人还是习惯于用群交流,所以我也搞了一个。 + +但是希望你进群之前要先想好,这个群并不希望大家讨论任何的社会议题,包括政治、历史、男女、情感等。我们希望这个群承担的功能是讨论技术问题和技术互助。 + +技术互助的意思是,你进群是希望有人来帮你解答问题;那么同样地,看到别人提问,也希望你能帮助解答。 + +![入群](./.imgs/contact_me_qr.jpg) \ No newline at end of file diff --git a/bean/copier/converter/converter.go b/bean/copier/converter/converter.go new file mode 100644 index 00000000..08becfad --- /dev/null +++ b/bean/copier/converter/converter.go @@ -0,0 +1,25 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package converter + +type Converter[Src any, Dst any] interface { + Convert(src Src) (Dst, error) +} + +type ConverterFunc[Src any, Dst any] func(src Src) (Dst, error) + +func (cf ConverterFunc[Src, Dst]) Convert(src Src) (Dst, error) { + return cf(src) +} diff --git a/bean/copier/converter/time2string.go b/bean/copier/converter/time2string.go new file mode 100644 index 00000000..a2d25b4f --- /dev/null +++ b/bean/copier/converter/time2string.go @@ -0,0 +1,25 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package converter + +import "time" + +type Time2String struct { + Pattern string +} + +func (t Time2String) Convert(src time.Time) (string, error) { + return src.Format(t.Pattern), nil +} diff --git a/bean/copier/copy.go b/bean/copier/copy.go index 00d4c046..6ea8048b 100644 --- a/bean/copier/copy.go +++ b/bean/copier/copy.go @@ -14,6 +14,12 @@ package copier +import ( + "github.com/ecodeclub/ekit/bean/copier/converter" + "github.com/ecodeclub/ekit/bean/option" + "github.com/ecodeclub/ekit/set" +) + // Copier 复制数据 // 1. 深拷贝亦或是浅拷贝,取决于具体的实现。每个实现都要声明清楚这一点; // 2. Src 和 Dst 都必须是普通的结构体,支持组合 @@ -21,7 +27,65 @@ package copier // 这种设计设计,即使用 *Src 和 *Dst 可能加剧内存逃逸 type Copier[Src any, Dst any] interface { // CopyTo 将 src 中的数据复制到 dst 中 - CopyTo(src *Src, dst *Dst) error + CopyTo(src *Src, dst *Dst, opts ...option.Option[options]) error // Copy 将创建一个 Dst 的实例,并且将 Src 中的数据复制过去 - Copy(src *Src) (*Dst, error) + Copy(src *Src, opts ...option.Option[options]) (*Dst, error) +} + +// options 执行复制操作时的可选配置 +type options struct { + // ignoreFields 执行复制操作时,需要忽略的字段 + ignoreFields *set.MapSet[string] + // convertFields 执行转换的field和转化接口的泛型包装 + convertFields map[string]converterWrapper +} + +type converterWrapper func(src any) (any, error) + +func newOptions() options { + return options{} +} + +// InIgnoreFields 判断 str 是不是在 ignoreFields 里面 +func (r *options) InIgnoreFields(str string) bool { + // 如果没有设置过忽略的字段的话,ignoreFields 就有可能是 nil,这里需要判断一下 + if r.ignoreFields == nil { + return false + } + return r.ignoreFields.Exist(str) +} + +// IgnoreFields 设置复制时要忽略的字段(option 设计模式) +func IgnoreFields(fields ...string) option.Option[options] { + return func(opt *options) { + if len(fields) < 1 { + return + } + // 需要用的时候再延迟初始化 ignoreFields + if opt.ignoreFields == nil { + opt.ignoreFields = set.NewMapSet[string](len(fields)) + } + for i := 0; i < len(fields); i++ { + opt.ignoreFields.Add(fields[i]) + } + } +} + +func ConvertField[Src any, Dst any](field string, converter converter.Converter[Src, Dst]) option.Option[options] { + return func(opt *options) { + if field == "" || converter == nil { + return + } + if opt.convertFields == nil { + opt.convertFields = make(map[string]converterWrapper, 8) + } + opt.convertFields[field] = func(src any) (any, error) { + var dst Dst + srcVal, ok := src.(Src) + if !ok { + return dst, errConvertFieldTypeNotMatch + } + return converter.Convert(srcVal) + } + } } diff --git a/bean/copier/errors.go b/bean/copier/errors.go index fc9ce62b..b57149c5 100644 --- a/bean/copier/errors.go +++ b/bean/copier/errors.go @@ -15,10 +15,15 @@ package copier import ( + "errors" "fmt" "reflect" ) +var ( + errConvertFieldTypeNotMatch = errors.New("ekit: 转化字段类型不匹配") +) + // newErrTypeError copier 不支持的类型 func newErrTypeError(typ reflect.Type) error { return fmt.Errorf("ekit: copier 入口只支持 Struct 不支持类型 %v, 种类 %v", typ, typ.Kind()) diff --git a/bean/copier/reflect_copier.go b/bean/copier/reflect_copier.go index 2c882664..aca46f31 100644 --- a/bean/copier/reflect_copier.go +++ b/bean/copier/reflect_copier.go @@ -16,14 +16,30 @@ package copier import ( "reflect" + "time" + + "github.com/ecodeclub/ekit/set" + + "github.com/ecodeclub/ekit/bean/option" ) +var defaultAtomicTypes = []reflect.Type{ + reflect.TypeOf(time.Time{}), +} + // ReflectCopier 基于反射的实现 // ReflectCopier 是浅拷贝 type ReflectCopier[Src any, Dst any] struct { // rootField 字典树的根节点 rootField fieldNode + + // options 执行复制操作时的可选配置 + // 如果默认配置和Copy()/CopyTo()中的配置同名,会替换defaultOptions同名内容 + // 初始化时的默认配置,仅作为记录,执行时会拷贝到options中 + defaultOptions options + + atomicTypes []reflect.Type } // fieldNode 字段的前缀树 @@ -45,7 +61,7 @@ type fieldNode struct { } // NewReflectCopier 如果类型不匹配, 创建时直接检查报错. -func NewReflectCopier[Src any, Dst any]() (*ReflectCopier[Src, Dst], error) { +func NewReflectCopier[Src any, Dst any](opts ...option.Option[options]) (*ReflectCopier[Src, Dst], error) { src := new(Src) srcTyp := reflect.TypeOf(src).Elem() dst := new(Dst) @@ -60,18 +76,24 @@ func NewReflectCopier[Src any, Dst any]() (*ReflectCopier[Src, Dst], error) { if dstTyp.Kind() != reflect.Struct { return nil, newErrTypeError(dstTyp) } - if err := createFieldNodes(&root, srcTyp, dstTyp); err != nil { - return nil, err - } copier := &ReflectCopier[Src, Dst]{ - rootField: root, + atomicTypes: defaultAtomicTypes, } + + if err := copier.createFieldNodes(&root, srcTyp, dstTyp); err != nil { + return nil, err + } + copier.rootField = root + + defaultOpts := newOptions() + option.Apply(&defaultOpts, opts...) + copier.defaultOptions = defaultOpts return copier, nil } // createFieldNodes 递归创建 field 的前缀树, srcTyp 和 dstTyp 只能是结构体 -func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error { +func (r *ReflectCopier[Src, Dst]) createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error { fieldMap := map[string]int{} for i := 0; i < srcTyp.NumField(); i++ { @@ -93,17 +115,12 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error { continue } srcFieldTypStruct := srcTyp.Field(srcIndex) - if srcFieldTypStruct.Type.Kind() != dstFieldTypStruct.Type.Kind() { - return newErrKindNotMatchError(srcFieldTypStruct.Type.Kind(), dstFieldTypStruct.Type.Kind(), dstFieldTypStruct.Name) - } - if srcFieldTypStruct.Type.Kind() == reflect.Pointer { - if srcFieldTypStruct.Type.Elem().Kind() != dstFieldTypStruct.Type.Elem().Kind() { - return newErrKindNotMatchError(srcFieldTypStruct.Type.Kind(), dstFieldTypStruct.Type.Kind(), dstFieldTypStruct.Name) - } - if srcFieldTypStruct.Type.Elem().Kind() == reflect.Pointer { - return newErrMultiPointer(dstFieldTypStruct.Name) - } + if srcFieldTypStruct.Type.Kind() == reflect.Pointer && srcFieldTypStruct.Type.Elem().Kind() == reflect.Pointer { + return newErrMultiPointer(srcFieldTypStruct.Name) + } + if dstFieldTypStruct.Type.Kind() == reflect.Pointer && dstFieldTypStruct.Type.Elem().Kind() == reflect.Pointer { + return newErrMultiPointer(dstFieldTypStruct.Name) } child := fieldNode{ @@ -118,18 +135,22 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error { fieldDstTyp := dstFieldTypStruct.Type if fieldSrcTyp.Kind() == reflect.Pointer { fieldSrcTyp = fieldSrcTyp.Elem() + } + + if fieldDstTyp.Kind() == reflect.Pointer { fieldDstTyp = fieldDstTyp.Elem() } if isShadowCopyType(fieldSrcTyp.Kind()) { // 内置类型,但不匹配,如别名、map和slice - if fieldSrcTyp != fieldDstTyp { - return newErrTypeNotMatchError(srcFieldTypStruct.Type, dstFieldTypStruct.Type, dstFieldTypStruct.Name) - } // 说明当前节点是叶子节点, 直接拷贝 child.isLeaf = true + } else if r.isAtomicType(fieldSrcTyp) { + // 指定可作为一个整体的类型,不用递归 + // 同上,当当前节点是叶子节点时, 直接拷贝 + child.isLeaf = true } else if fieldSrcTyp.Kind() == reflect.Struct { - if err := createFieldNodes(&child, fieldSrcTyp, fieldDstTyp); err != nil { + if err := r.createFieldNodes(&child, fieldSrcTyp, fieldDstTyp); err != nil { return err } } else { @@ -142,9 +163,9 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error { return nil } -func (r *ReflectCopier[Src, Dst]) Copy(src *Src) (*Dst, error) { +func (r *ReflectCopier[Src, Dst]) Copy(src *Src, opts ...option.Option[options]) (*Dst, error) { dst := new(Dst) - err := r.CopyTo(src, dst) + err := r.CopyTo(src, dst, opts...) return dst, err } @@ -154,55 +175,130 @@ func (r *ReflectCopier[Src, Dst]) Copy(src *Src) (*Dst, error) { // 2. 如果 Src 和 Dst 中匹配的字段,其类型是基本类型(及其指针)或者内置类型(及其指针),并且类型一样,则直接用 Src 的值 // 3. 如果 Src 和 Dst 中匹配的字段,其类型都是结构体,或者都是结构体指针,则会深入复制 // 4. 否则,忽略字段 -func (r *ReflectCopier[Src, Dst]) CopyTo(src *Src, dst *Dst) error { - return r.copyToWithTree(src, dst) +func (r *ReflectCopier[Src, Dst]) CopyTo(src *Src, dst *Dst, opts ...option.Option[options]) error { + localOption := r.copyDefaultOptions() + option.Apply(&localOption, opts...) + return r.copyToWithTree(src, dst, localOption) +} + +// copyDefaultOptions 复制默认配置 +func (r *ReflectCopier[Src, Dst]) copyDefaultOptions() options { + localOption := newOptions() + // 复制ignoreFields default配置 + if r.defaultOptions.ignoreFields != nil { + ignoreFields := set.NewMapSet[string](8) + for _, key := range r.defaultOptions.ignoreFields.Keys() { + ignoreFields.Add(key) + } + localOption.ignoreFields = ignoreFields + } + + // 复制convertFields default配置 + for field, convert := range r.defaultOptions.convertFields { + if localOption.convertFields == nil { + localOption.convertFields = make(map[string]converterWrapper, 8) + } + localOption.convertFields[field] = convert + } + return localOption } -func (r *ReflectCopier[Src, Dst]) copyToWithTree(src *Src, dst *Dst) error { +func (r *ReflectCopier[Src, Dst]) copyToWithTree(src *Src, dst *Dst, opts options) error { srcTyp := reflect.TypeOf(src) dstTyp := reflect.TypeOf(dst) srcValue := reflect.ValueOf(src) dstValue := reflect.ValueOf(dst) - return r.copyTreeNode(srcTyp, srcValue, dstTyp, dstValue, &r.rootField) + return r.copyTreeNode(srcTyp, srcValue, dstTyp, dstValue, &r.rootField, opts) } -func (r *ReflectCopier[Src, Dst]) copyTreeNode(srcTyp reflect.Type, srcValue reflect.Value, dstType reflect.Type, dstValue reflect.Value, root *fieldNode) error { +func (r *ReflectCopier[Src, Dst]) copyTreeNode(srcTyp reflect.Type, srcValue reflect.Value, + dstType reflect.Type, dstValue reflect.Value, root *fieldNode, opts options) error { + originSrcVal := srcValue + originDstVal := dstValue if srcValue.Kind() == reflect.Pointer { if srcValue.IsNil() { return nil } - if dstValue.IsNil() { - dstValue.Set(reflect.New(dstType.Elem())) - } srcValue = srcValue.Elem() srcTyp = srcTyp.Elem() + } + if dstValue.Kind() == reflect.Pointer { + if dstValue.IsNil() { + dstValue.Set(reflect.New(dstType.Elem())) + } dstValue = dstValue.Elem() dstType = dstType.Elem() } + // 执行拷贝 if root.isLeaf { - if dstValue.CanSet() { + convert, ok := opts.convertFields[root.name] + if !dstValue.CanSet() { + return nil + } + // 获取convert失败,就需要检测类型是否匹配,类型匹配就直接set + if !ok { + if srcTyp != dstType { + return newErrTypeNotMatchError(srcTyp, dstType, root.name) + } + if srcValue.IsZero() { + return nil + } dstValue.Set(srcValue) + return nil + } + + // 字段执行转换函数时,需要用到原始类型进行判断,set的时候也是根据原始value设置 + if !originDstVal.CanSet() { + return nil } + srcConv, err := convert(originSrcVal.Interface()) + if err != nil { + return err + } + + srcConvType := reflect.TypeOf(srcConv) + srcConvVal := reflect.ValueOf(srcConv) + // 待设置的value和转换获取的value类型不匹配 + if srcConvType != originDstVal.Type() { + return newErrTypeNotMatchError(srcConvType, originDstVal.Type(), root.name) + } + + originDstVal.Set(srcConvVal) return nil } for i := range root.fields { child := &root.fields[i] + + // 只要结构体属性的名字在需要忽略的字段里面,就不走下面的复制逻辑 + if opts.InIgnoreFields(child.name) { + continue + } + childSrcTyp := srcTyp.Field(child.srcIndex) childSrcValue := srcValue.Field(child.srcIndex) childDstTyp := dstType.Field(child.dstIndex) childDstValue := dstValue.Field(child.dstIndex) - if err := r.copyTreeNode(childSrcTyp.Type, childSrcValue, childDstTyp.Type, childDstValue, child); err != nil { + if err := r.copyTreeNode(childSrcTyp.Type, childSrcValue, childDstTyp.Type, childDstValue, child, opts); err != nil { return err } } return nil } +func (r *ReflectCopier[Src, Dst]) isAtomicType(typ reflect.Type) bool { + for _, dt := range r.atomicTypes { + if dt == typ { + return true + } + } + return false +} + func isShadowCopyType(kind reflect.Kind) bool { switch kind { case reflect.Bool, diff --git a/bean/copier/reflect_copier_test.go b/bean/copier/reflect_copier_test.go index d4d9d888..cef83977 100644 --- a/bean/copier/reflect_copier_test.go +++ b/bean/copier/reflect_copier_test.go @@ -15,14 +15,21 @@ package copier import ( + "fmt" "reflect" + "strconv" + "sync" "testing" + "time" + + "github.com/ecodeclub/ekit/bean/copier/converter" "github.com/ecodeclub/ekit" "github.com/stretchr/testify/assert" ) func TestReflectCopier_Copy(t *testing.T) { + t.Parallel() testCases := []struct { name string copyFunc func() (any, error) @@ -267,7 +274,7 @@ func TestReflectCopier_Copy(t *testing.T) { S: struct{ A string }{A: "a"}, }) }, - wantErr: newErrKindNotMatchError(reflect.String, reflect.Int, "A"), + wantErr: newErrTypeNotMatchError(reflect.TypeOf(""), reflect.TypeOf(0), "A"), }, { name: "多重指针", @@ -505,165 +512,1057 @@ func TestReflectCopier_Copy(t *testing.T) { }, wantDst: &SpecialDst2{A: 1}, }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - res, err := tc.copyFunc() - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantDst, res) - }) - } -} - -type BasicSrc struct { - Name string - Age int - CNumber complex64 -} - -type BasicDst struct { - Name string - Age int - CNumber complex64 -} - -type SimpleSrc struct { - Name string - Age *int - Friends []string -} - -type SimpleDst struct { - Name string - Age *int - Friends []string -} - -type EmbedSrc struct { - SimpleSrc - *BasicSrc -} - -type EmbedDst struct { - SimpleSrc - *BasicSrc -} - -type ComplexSrc struct { - Simple SimpleSrc - Embed *EmbedSrc - BasicSrc -} - -type ComplexDst struct { - Simple SimpleDst - Embed *EmbedDst - BasicSrc -} - -type SpecialSrc struct { - Arr [3]float32 - M map[string]int -} - -type SpecialDst struct { - Arr [3]float32 - M map[string]int -} - -type InterfaceSrc interface { -} - -type InterfaceDst interface { -} - -type NotMatchSrc struct { - Simple SimpleSrc - Embed *EmbedSrc - BasicSrc - S struct { - A string - } -} - -type NotMatchDst struct { - Simple SimpleDst - Embed *EmbedDst - BasicSrc - S struct { - A int - } -} - -type MultiPtrSrc struct { - Name string - Age **int - Friends []string -} - -type MultiPtrDst struct { - Name string - Age **int - Friends []string -} - -type DiffSrc struct { - A string - B int - c SimpleSrc - F BasicSrc -} -type DiffDst struct { - A string - B int - d SimpleSrc - G BasicSrc -} - -type SimpleEmbedDst struct { - SimpleSrc -} - -type ArraySrc struct { - A []SimpleSrc -} - -type ArrayDst struct { - A []SimpleSrc -} - -type ArrayDst1 struct { - A []SimpleDst -} - -type MapSrc struct { - A map[string]SimpleSrc -} - -type MapDst struct { - A map[string]SimpleSrc -} - -type MapDst1 struct { - A map[string]SimpleDst -} - -type SpecialSrc1 struct { - A int -} - -type aliasInt int -type SpecialDst1 struct { - A aliasInt -} - -type aliasInt1 = int -type SpecialDst2 struct { - A aliasInt1 + { + name: "simple_struct_忽略字段的时候传空", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SimpleSrc, SimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy(&SimpleSrc{ + Name: "大明", + Age: ekit.ToPtr[int](18), + Friends: []string{"Tom", "Jerry"}, + }, IgnoreFields()) + }, + wantDst: &SimpleDst{ + Name: "大明", + Age: ekit.ToPtr[int](18), + Friends: []string{"Tom", "Jerry"}, + }, + }, + { + name: "simple_struct_忽略一个字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SimpleSrc, SimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy(&SimpleSrc{ + Name: "大明", + Age: ekit.ToPtr[int](18), + Friends: []string{"Tom", "Jerry"}, + }, IgnoreFields("Age")) + }, + wantDst: &SimpleDst{ + Name: "大明", + Age: nil, + Friends: []string{"Tom", "Jerry"}, + }, + }, + { + name: "simple_struct_忽略多个字段_传入多个Option_每个Option传入一个字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SimpleSrc, SimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy(&SimpleSrc{ + Name: "大明", + Age: ekit.ToPtr[int](18), + Friends: []string{"Tom", "Jerry"}, + }, IgnoreFields("Age"), IgnoreFields("Friends")) + }, + wantDst: &SimpleDst{ + Name: "大明", + Age: nil, + Friends: nil, + }, + }, + { + name: "simple_struct_忽略多个字段_传入一个Option_Option传入多个字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SimpleSrc, SimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy(&SimpleSrc{ + Name: "大明", + Age: ekit.ToPtr[int](18), + Friends: []string{"Tom", "Jerry"}, + }, IgnoreFields("Age", "Friends")) + }, + wantDst: &SimpleDst{ + Name: "大明", + Age: nil, + Friends: nil, + }, + }, + { + name: "simple_struct_忽略全部字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SimpleSrc, SimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy(&SimpleSrc{ + Name: "大明", + Age: ekit.ToPtr[int](18), + Friends: []string{"Tom", "Jerry"}, + }, IgnoreFields("Name"), IgnoreFields("Age"), IgnoreFields("Friends")) + }, + wantDst: &SimpleDst{}, + }, + { + name: "simple_struct_空切片_空指针_忽略字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SimpleSrc, SimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy(&SimpleSrc{ + Name: "大明", + }, IgnoreFields("Name")) + }, + wantDst: &SimpleDst{ + Name: "", + }, + }, + { + name: "组合_struct_忽略组合中的一个字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[EmbedSrc, EmbedDst]() + if err != nil { + return nil, err + } + return copier.Copy(&EmbedSrc{ + SimpleSrc: SimpleSrc{ + Name: "xiaoli", + Age: ekit.ToPtr[int](19), + Friends: []string{}, + }, + BasicSrc: &BasicSrc{ + Name: "xiaowang", + Age: 20, + CNumber: complex(2, 2), + }, + }, IgnoreFields("CNumber")) + }, + wantDst: &EmbedDst{ + SimpleSrc: SimpleSrc{ + Name: "xiaoli", + Age: ekit.ToPtr[int](19), + Friends: []string{}, + }, + BasicSrc: &BasicSrc{ + Name: "xiaowang", + Age: 20, + CNumber: complex(0, 0), + }, + }, + }, + { + name: "组合_struct_忽略组合中全部同名字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[EmbedSrc, EmbedDst]() + if err != nil { + return nil, err + } + return copier.Copy(&EmbedSrc{ + SimpleSrc: SimpleSrc{ + Name: "xiaoli", + Age: ekit.ToPtr[int](19), + Friends: []string{}, + }, + BasicSrc: &BasicSrc{ + Name: "xiaowang", + Age: 20, + CNumber: complex(2, 2), + }, + }, IgnoreFields("Age")) + }, + wantDst: &EmbedDst{ + SimpleSrc: SimpleSrc{ + Name: "xiaoli", + Age: nil, + Friends: []string{}, + }, + BasicSrc: &BasicSrc{ + Name: "xiaowang", + Age: 0, + CNumber: complex(2, 2), + }, + }, + }, + { + name: "组合_struct_忽略组合中同名结构体", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[EmbedSrc, EmbedDst]() + if err != nil { + return nil, err + } + return copier.Copy(&EmbedSrc{ + SimpleSrc: SimpleSrc{ + Name: "xiaoli", + Age: ekit.ToPtr[int](19), + Friends: []string{}, + }, + BasicSrc: &BasicSrc{ + Name: "xiaowang", + Age: 20, + CNumber: complex(2, 2), + }, + }, IgnoreFields("SimpleSrc")) + }, + wantDst: &EmbedDst{ + SimpleSrc: SimpleSrc{}, + BasicSrc: &BasicSrc{ + Name: "xiaowang", + Age: 20, + CNumber: complex(2, 2), + }, + }, + }, + { + name: "复杂_Struct_忽略多层嵌套中全部同名字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ComplexSrc, ComplexDst]() + if err != nil { + return nil, err + } + return copier.Copy(&ComplexSrc{ + Simple: SimpleSrc{ + Name: "xiaohong", + Age: ekit.ToPtr[int](18), + Friends: []string{"ha", "ha", "le"}, + }, + Embed: &EmbedSrc{ + SimpleSrc: SimpleSrc{ + Name: "xiaopeng", + Age: ekit.ToPtr[int](88), + Friends: []string{"la", "ha", "le"}, + }, + BasicSrc: &BasicSrc{ + Name: "wang", + Age: 22, + CNumber: complex(2, 1), + }, + }, + BasicSrc: BasicSrc{ + Name: "wang11", + Age: 22, + CNumber: complex(2, 1), + }, + }, IgnoreFields("Age")) + }, + wantDst: &ComplexDst{ + Simple: SimpleDst{ + Name: "xiaohong", + Age: nil, + Friends: []string{"ha", "ha", "le"}, + }, + Embed: &EmbedDst{ + SimpleSrc: SimpleSrc{ + Name: "xiaopeng", + Age: nil, + Friends: []string{"la", "ha", "le"}, + }, + BasicSrc: &BasicSrc{ + Name: "wang", + Age: 0, + CNumber: complex(2, 1), + }, + }, + BasicSrc: BasicSrc{ + Name: "wang11", + Age: 0, + CNumber: complex(2, 1), + }, + }, + }, + { + name: "复杂_Struct_忽略多层嵌套中的同名结构体", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ComplexSrc, ComplexDst]() + if err != nil { + return nil, err + } + return copier.Copy(&ComplexSrc{ + Simple: SimpleSrc{ + Name: "xiaohong", + Age: ekit.ToPtr[int](18), + Friends: []string{"ha", "ha", "le"}, + }, + Embed: &EmbedSrc{ + SimpleSrc: SimpleSrc{ + Name: "xiaopeng", + Age: ekit.ToPtr[int](88), + Friends: []string{"la", "ha", "le"}, + }, + BasicSrc: &BasicSrc{ + Name: "wang", + Age: 22, + CNumber: complex(2, 1), + }, + }, + BasicSrc: BasicSrc{ + Name: "wang11", + Age: 22, + CNumber: complex(2, 1), + }, + }, IgnoreFields("SimpleSrc")) + }, + wantDst: &ComplexDst{ + Simple: SimpleDst{ + Name: "xiaohong", + Age: ekit.ToPtr[int](18), + Friends: []string{"ha", "ha", "le"}, + }, + Embed: &EmbedDst{ + SimpleSrc: SimpleSrc{}, + BasicSrc: &BasicSrc{ + Name: "wang", + Age: 22, + CNumber: complex(2, 1), + }, + }, + BasicSrc: BasicSrc{ + Name: "wang11", + Age: 22, + CNumber: complex(2, 1), + }, + }, + }, + { + name: "复杂_Struct_忽略多层嵌套中的整个结构体", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ComplexSrc, ComplexDst]() + if err != nil { + return nil, err + } + return copier.Copy(&ComplexSrc{ + Simple: SimpleSrc{ + Name: "xiaohong", + Age: ekit.ToPtr[int](18), + Friends: []string{"ha", "ha", "le"}, + }, + Embed: &EmbedSrc{ + SimpleSrc: SimpleSrc{ + Name: "xiaopeng", + Age: ekit.ToPtr[int](88), + Friends: []string{"la", "ha", "le"}, + }, + BasicSrc: &BasicSrc{ + Name: "wang", + Age: 22, + CNumber: complex(2, 1), + }, + }, + BasicSrc: BasicSrc{ + Name: "wang11", + Age: 22, + CNumber: complex(2, 1), + }, + }, IgnoreFields("Embed")) + }, + wantDst: &ComplexDst{ + Simple: SimpleDst{ + Name: "xiaohong", + Age: ekit.ToPtr[int](18), + Friends: []string{"ha", "ha", "le"}, + }, + Embed: nil, + BasicSrc: BasicSrc{ + Name: "wang11", + Age: 22, + CNumber: complex(2, 1), + }, + }, + }, + { + name: "特殊类型_忽略结构体中的切片", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SpecialSrc, SpecialDst]() + if err != nil { + return nil, err + } + return copier.Copy(&SpecialSrc{ + Arr: [3]float32{1, 2, 3}, + M: map[string]int{ + "ha": 1, + "o": 2, + }, + }, IgnoreFields("Arr")) + }, + wantDst: &SpecialDst{ + Arr: [3]float32{}, + M: map[string]int{ + "ha": 1, + "o": 2, + }, + }, + }, + { + name: "特殊类型_忽略结构体中的map", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SpecialSrc, SpecialDst]() + if err != nil { + return nil, err + } + return copier.Copy(&SpecialSrc{ + Arr: [3]float32{1, 2, 3}, + M: map[string]int{ + "ha": 1, + "o": 2, + }, + }, IgnoreFields("M")) + }, + wantDst: &SpecialDst{ + Arr: [3]float32{1, 2, 3}, + M: nil, + }, + }, + { + name: "dst_有额外字段_忽略一个字段_其他字段会被赋值", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[DiffSrc, DiffDst]() + if err != nil { + return nil, err + } + dst := &DiffDst{ + A: "66", + B: 1, + d: SimpleSrc{ + Name: "wodemingzi", + Age: ekit.ToPtr(int(10)), + }, + G: BasicSrc{ + Name: "nidemingzi", + Age: 23, + CNumber: complex(1, 2), + }, + } + err = copier.CopyTo(&DiffSrc{ + A: "xiaowang", + B: 100, + c: SimpleSrc{ + Name: "66", + Age: ekit.ToPtr[int](100), + }, + F: BasicSrc{ + Name: "good name", + Age: 200, + CNumber: complex(2, 2), + }, + }, dst, IgnoreFields("A")) + return dst, err + }, + wantDst: &DiffDst{ + A: "66", + B: 100, + d: SimpleSrc{ + Name: "wodemingzi", + Age: ekit.ToPtr(int(10)), + }, + G: BasicSrc{ + Name: "nidemingzi", + Age: 23, + CNumber: complex(1, 2), + }, + }, + }, + { + name: "dst_有额外字段_不会忽略dst的字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[DiffSrc, DiffDst]() + if err != nil { + return nil, err + } + dst := &DiffDst{ + A: "66", + B: 1, + d: SimpleSrc{ + Name: "wodemingzi", + Age: ekit.ToPtr(int(10)), + }, + G: BasicSrc{ + Name: "nidemingzi", + Age: 23, + CNumber: complex(1, 2), + }, + } + err = copier.CopyTo(&DiffSrc{ + A: "xiaowang", + B: 100, + c: SimpleSrc{ + Name: "66", + Age: ekit.ToPtr[int](100), + }, + F: BasicSrc{ + Name: "good name", + Age: 200, + CNumber: complex(2, 2), + }, + }, dst, IgnoreFields("G")) + return dst, err + }, + wantDst: &DiffDst{ + A: "xiaowang", + B: 100, + d: SimpleSrc{ + Name: "wodemingzi", + Age: ekit.ToPtr(int(10)), + }, + G: BasicSrc{ + Name: "nidemingzi", + Age: 23, + CNumber: complex(1, 2), + }, + }, + }, + { + name: "成员为结构体数组_不会忽略结构体中的字段", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ArraySrc, ArrayDst]() + if err != nil { + return nil, err + } + return copier.Copy(&ArraySrc{ + A: []SimpleSrc{ + { + Name: "大明", + Age: ekit.ToPtr[int](18), + Friends: []string{"Tom", "Jerry"}, + }, + { + Name: "小明", + Age: ekit.ToPtr[int](8), + Friends: []string{"Tom"}, + }, + }, + }, IgnoreFields("Age")) + }, + wantDst: &ArrayDst{ + A: []SimpleSrc{ + { + Name: "大明", + Age: ekit.ToPtr[int](18), + Friends: []string{"Tom", "Jerry"}, + }, + { + Name: "小明", + Age: ekit.ToPtr[int](8), + Friends: []string{"Tom"}, + }, + }, + }, + }, + { + name: "指定convert time2string,src为nil", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy(&ConvSimpleSrc{}, ConvertField[time.Time, string]("BirthDay", converter.Time2String{Pattern: "2006-01-02 15:04:05"})) + }, + wantDst: &ConvSimpleDst{ + BirthDay: "0001-01-01 00:00:00", + }, + }, + { + name: "指定convert time2string", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy(&ConvSimpleSrc{ + Name: "大明", + BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC), + Friends: []string{"Tom", "Jerry"}, + }, ConvertField[time.Time, string]("BirthDay", converter.Time2String{Pattern: "2006-01-02 15:04:05"})) + }, + wantDst: &ConvSimpleDst{ + Name: "大明", + BirthDay: "2023-07-26 09:15:22", + Friends: []string{"Tom", "Jerry"}, + }, + }, + { + name: "指定convert func, src为nil", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy( + &ConvSimpleSrc{}, + ConvertField[string, string]( + "Name", + converter.ConverterFunc[string, string](func(src string) (string, error) { + newS := fmt.Sprintf("%s plus", src) + return newS, nil + }), + ), + ConvertField[time.Time, string]( + "BirthDay", + converter.ConverterFunc[time.Time, string](func(src time.Time) (string, error) { + return src.Format("2006-01-02 15:04:05"), nil + }), + ), + ConvertField[*int, *int]( + "Age", + converter.ConverterFunc[*int, *int](func(src *int) (*int, error) { + newS := *src + 1 + return &newS, nil + }), + ), + ConvertField[[]string, []string]( + "Friends", + converter.ConverterFunc[[]string, []string](func(src []string) ([]string, error) { + return []string{"Tom", "Jerry"}, nil + }), + ), + ) + }, + wantDst: &ConvSimpleDst{ + Name: " plus", + Age: nil, + BirthDay: "0001-01-01 00:00:00", + Friends: []string{"Tom", "Jerry"}, + }, + }, + { + name: "指定convert func, dst值为nil", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy( + &ConvSimpleSrc{ + Name: "大明", + Age: ekit.ToPtr[int](11), + BirthDay: time.Now(), + Friends: []string{"Tom", "Jerry"}, + }, + ConvertField[string, string]( + "Name", + converter.ConverterFunc[string, string](func(src string) (string, error) { + return "", nil + }), + ), + ConvertField[time.Time, string]( + "BirthDay", + converter.ConverterFunc[time.Time, string](func(src time.Time) (string, error) { + return "", nil + }), + ), + ConvertField[*int, *int]( + "Age", + converter.ConverterFunc[*int, *int](func(src *int) (*int, error) { + return nil, nil + }), + ), + ConvertField[[]string, []string]( + "Friends", + converter.ConverterFunc[[]string, []string](func(src []string) ([]string, error) { + return nil, nil + }), + ), + ) + }, + wantDst: &ConvSimpleDst{ + Name: "", + BirthDay: "", + Age: nil, + Friends: nil, + }, + }, + { + name: "指定convert func", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]() + if err != nil { + return nil, err + } + return copier.Copy( + &ConvSimpleSrc{ + Name: "大明", + Age: ekit.ToPtr[int](15), + BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC), + Friends: []string{"Tom", "Jerry"}, + }, + ConvertField[string, string]( + "Name", + converter.ConverterFunc[string, string](func(src string) (string, error) { + newS := fmt.Sprintf("%s plus", src) + return newS, nil + }), + ), + ConvertField[time.Time, string]( + "BirthDay", + converter.ConverterFunc[time.Time, string](func(src time.Time) (string, error) { + return src.Format("2006-01-02 15:04:05"), nil + }), + ), + ConvertField[*int, *int]( + "Age", + converter.ConverterFunc[*int, *int](func(src *int) (*int, error) { + newS := *src + 1 + return &newS, nil + }), + ), + ) + }, + wantDst: &ConvSimpleDst{ + Name: "大明 plus", + Age: ekit.ToPtr[int](16), + BirthDay: "2023-07-26 09:15:22", + Friends: []string{"Tom", "Jerry"}, + }, + }, + { + name: "指定返回特殊类型的convert func", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSpecialSrc, ConvSpecialDst]() + if err != nil { + return nil, err + } + return copier.Copy(&ConvSpecialSrc{ + Arr: [3]float32{1, 2, 3}, + M: map[string]int{"a": 4, "b": 5, "c": 6}, + Diff: map[string]int{"a1": 41, "b1": 51, "c1": 61}, + }, ConvertField[map[string]int, map[string]int]( + "M", + converter.ConverterFunc[map[string]int, map[string]int](func(src map[string]int) (map[string]int, error) { + newM := map[string]int{"a1": 41, "b1": 51, "c1": 61} + return newM, nil + })), + ConvertField[map[string]int, []int]( + "Diff", + converter.ConverterFunc[map[string]int, []int](func(src map[string]int) ([]int, error) { + newM := []int{1, 1, 1} + return newM, nil + })), + ) + }, + wantDst: &ConvSpecialDst{ + Arr: [3]float32{1, 2, 3}, + M: map[string]int{"a1": 41, "b1": 51, "c1": 61}, + Diff: []int{1, 1, 1}, + }, + }, + { + name: "创建时指定默认converter", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]( + ConvertField[time.Time, string]( + "BirthDay", + converter.Time2String{Pattern: "2006-01-02 15:04:05"}, + ), + ) + if err != nil { + return nil, err + } + return copier.Copy(&ConvSimpleSrc{ + Name: "大明", + BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC), + Friends: []string{"Tom", "Jerry"}, + }, ConvertField[string, string]("Name", converter.ConverterFunc[string, string](func(src string) (string, error) { + newS := fmt.Sprintf("%s plus", src) + return newS, nil + }))) + }, + wantDst: &ConvSimpleDst{ + Name: "大明 plus", + BirthDay: "2023-07-26 09:15:22", + Friends: []string{"Tom", "Jerry"}, + }, + }, + { + name: "创建时指定默认converter,convert同一个字段会覆盖", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]( + ConvertField[time.Time, string]( + "BirthDay", + converter.Time2String{Pattern: "2006-01-02 15:04:05"}, + ), + ) + if err != nil { + return nil, err + } + return copier.Copy(&ConvSimpleSrc{ + BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC), + }, ConvertField[time.Time, string]("BirthDay", converter.ConverterFunc[time.Time, string](func(src time.Time) (string, error) { + return "1234567", nil + }))) + }, + wantDst: &ConvSimpleDst{ + BirthDay: "1234567", + }, + }, + { + name: "创建时指定默认converter,convert同一个字段会覆盖,覆盖后不影响默认配置", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]( + ConvertField[time.Time, string]( + "BirthDay", + converter.Time2String{Pattern: "2006-01-02 15:04:05"}, + ), + ) + if err != nil { + return nil, err + } + // 第一次执行Copy,函数中指定converter + _, err = copier.Copy( + &ConvSimpleSrc{BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC)}, + ConvertField[time.Time, string]( + "BirthDay", + converter.ConverterFunc[time.Time, string](func(src time.Time) (string, error) { + return "1234567", nil + }))) + if err != nil { + return nil, err + } + // 第二次执行Copy,函数中不指定converter,走默认 + return copier.Copy(&ConvSimpleSrc{ + BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC), + }) + }, + wantDst: &ConvSimpleDst{ + BirthDay: "2023-07-26 09:15:22", + }, + }, + { + name: "创建时指定默认忽略字段,Copy()时指定的忽略字段不影响默认", + copyFunc: func() (any, error) { + copier, err := NewReflectCopier[SimpleSrc, SimpleDst](IgnoreFields("Age")) + if err != nil { + return nil, err + } + // 第一次执行Copy,函数中指定ignore字段 + _, err = copier.Copy(&SimpleSrc{ + Name: "大明", + Age: ekit.ToPtr[int](11), + }, IgnoreFields("Name")) + if err != nil { + return nil, err + } + // 第二次执行Copy,函数中不指定ignore字段,走默认 + return copier.Copy(&SimpleSrc{ + Name: "大明", + }) + }, + wantDst: &SimpleDst{ + Name: "大明", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res, err := tc.copyFunc() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantDst, res) + }) + } +} + +func Test_Concurrency_Copy(t *testing.T) { + copier, err := NewReflectCopier[ConvSimpleSrc, ConvSimpleDst]( + ConvertField[time.Time, string]( + "BirthDay", + converter.Time2String{Pattern: "2006-01-02 15:04:05"}, + ), + ) + assert.Nil(t, err) + + var wg sync.WaitGroup + wg.Add(100) + for i := 0; i < 100; i++ { + go func(i int) { + defer wg.Done() + val := strconv.Itoa(i) + c, err := copier.Copy( + &ConvSimpleSrc{BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC)}, + ConvertField[time.Time, string]( + "BirthDay", + converter.ConverterFunc[time.Time, string](func(src time.Time) (string, error) { + return val, nil + }))) + assert.Nil(t, err) + assert.Equal(t, &ConvSimpleDst{BirthDay: val}, c) + }(i) + } + wg.Wait() +} + +type BasicSrc struct { + Name string + Age int + CNumber complex64 +} + +type BasicDst struct { + Name string + Age int + CNumber complex64 +} + +type SimpleSrc struct { + Name string + Age *int + Friends []string +} + +type SimpleDst struct { + Name string + Age *int + Friends []string +} + +type EmbedSrc struct { + SimpleSrc + *BasicSrc +} + +type EmbedDst struct { + SimpleSrc + *BasicSrc +} + +type ComplexSrc struct { + Simple SimpleSrc + Embed *EmbedSrc + BasicSrc +} + +type ComplexDst struct { + Simple SimpleDst + Embed *EmbedDst + BasicSrc +} + +type SpecialSrc struct { + Arr [3]float32 + M map[string]int +} + +type SpecialDst struct { + Arr [3]float32 + M map[string]int +} + +type InterfaceSrc interface { +} + +type InterfaceDst interface { +} + +type NotMatchSrc struct { + Simple SimpleSrc + Embed *EmbedSrc + BasicSrc + S struct { + A string + } +} + +type NotMatchDst struct { + Simple SimpleDst + Embed *EmbedDst + BasicSrc + S struct { + A int + } +} + +type MultiPtrSrc struct { + Name string + Age **int + Friends []string +} + +type MultiPtrDst struct { + Name string + Age **int + Friends []string +} + +type DiffSrc struct { + A string + B int + c SimpleSrc + F BasicSrc +} +type DiffDst struct { + A string + B int + d SimpleSrc + G BasicSrc +} + +type SimpleEmbedDst struct { + SimpleSrc +} + +type ArraySrc struct { + A []SimpleSrc +} + +type ArrayDst struct { + A []SimpleSrc +} + +type ArrayDst1 struct { + A []SimpleDst +} + +type MapSrc struct { + A map[string]SimpleSrc +} + +type MapDst struct { + A map[string]SimpleSrc +} + +type MapDst1 struct { + A map[string]SimpleDst +} + +type SpecialSrc1 struct { + A int +} + +type aliasInt int +type SpecialDst1 struct { + A aliasInt +} + +type aliasInt1 = int +type SpecialDst2 struct { + A aliasInt1 +} + +type ConvSimpleSrc struct { + Name string + Age *int + BirthDay time.Time + Friends []string +} + +type ConvSimpleDst struct { + Name string + Age *int + BirthDay string + Friends []string +} + +type ConvSpecialSrc struct { + Arr [3]float32 + M map[string]int + Diff map[string]int +} + +type ConvSpecialDst struct { + Arr [3]float32 + M map[string]int + Diff []int } func BenchmarkReflectCopier_Copy(b *testing.B) { diff --git a/go.mod b/go.mod index 4b69fc2b..80ea97db 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/ecodeclub/ekit go 1.20 require ( + github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/mattn/go-sqlite3 v1.14.15 github.com/stretchr/testify v1.8.1 golang.org/x/sync v0.1.0 diff --git a/go.sum b/go.sum index b786e861..92a1b932 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/internal/slice/add.go b/internal/slice/add.go new file mode 100644 index 00000000..c368c50c --- /dev/null +++ b/internal/slice/add.go @@ -0,0 +1,35 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import "github.com/ecodeclub/ekit/internal/errs" + +func Add[T any](src []T, element T, index int) ([]T, error) { + length := len(src) + if index < 0 || index >= length { + return nil, errs.NewErrIndexOutOfRange(length, index) + } + + //先将src扩展一个元素 + var zeroValue T + src = append(src, zeroValue) + for i := len(src) - 1; i > index; i-- { + if i-1 >= 0 { + src[i] = src[i-1] + } + } + src[index] = element + return src, nil +} diff --git a/internal/slice/add_test.go b/internal/slice/add_test.go new file mode 100644 index 00000000..7236aece --- /dev/null +++ b/internal/slice/add_test.go @@ -0,0 +1,78 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import ( + "testing" + + "github.com/ecodeclub/ekit/internal/errs" + "github.com/stretchr/testify/assert" +) + +func TestAdd(t *testing.T) { + testCases := []struct { + name string + slice []int + addVal int + index int + wantSlice []int + wantErr error + }{ + { + name: "index 0", + slice: []int{123, 100}, + addVal: 233, + index: 0, + wantSlice: []int{233, 123, 100}, + }, + { + name: "index middle", + slice: []int{123, 124, 125}, + addVal: 233, + index: 1, + wantSlice: []int{123, 233, 124, 125}, + }, + { + name: "index out of range", + slice: []int{123, 100}, + index: 12, + wantErr: errs.NewErrIndexOutOfRange(2, 12), + }, + { + name: "index less than 0", + slice: []int{123, 100}, + index: -1, + wantErr: errs.NewErrIndexOutOfRange(2, -1), + }, + { + name: "index last", + slice: []int{123, 100, 101, 102, 102, 102}, + addVal: 233, + index: 5, + wantSlice: []int{123, 100, 101, 102, 102, 233, 102}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res, err := Add(tc.slice, tc.addVal, tc.index) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantSlice, res) + }) + } +} diff --git a/internal/slice/delete.go b/internal/slice/delete.go index 75468b1a..8a011c76 100644 --- a/internal/slice/delete.go +++ b/internal/slice/delete.go @@ -22,14 +22,12 @@ func Delete[T any](src []T, index int) ([]T, T, error) { var zero T return nil, zero, errs.NewErrIndexOutOfRange(length, index) } - j := 0 res := src[index] - for i, v := range src { - if i != index { - src[j] = v - j++ - } + //从index位置开始,后面的元素依次往前挪1个位置 + for i := index; i+1 < length; i++ { + src[i] = src[i+1] } - src = src[:j] + //去掉最后一个重复元素 + src = src[:length-1] return src, res, nil } diff --git a/internal/tree/red_black_tree.go b/internal/tree/red_black_tree.go index 3a38ea63..cec0eac0 100644 --- a/internal/tree/red_black_tree.go +++ b/internal/tree/red_black_tree.go @@ -85,10 +85,14 @@ func (rb *RBTree[K, V]) Add(key K, value V) error { } // Delete 删除节点 -func (rb *RBTree[K, V]) Delete(key K) { +func (rb *RBTree[K, V]) Delete(key K) (V, bool) { if node := rb.findNode(key); node != nil { + value := node.value rb.deleteNode(node) + return value, true } + var v V + return v, false } // Find 查找节点 @@ -184,7 +188,8 @@ func (rb *RBTree[K, V]) addNode(node *rbNode[K, V]) error { // 着色旋转 // case1:当删除节点非空且为黑色时,会违反红黑树任何路径黑节点个数相同的约束,所以需要重新平衡 // case2:当删除红色节点时,不会破坏任何约束,所以不需要平衡 -func (rb *RBTree[K, V]) deleteNode(node *rbNode[K, V]) { +func (rb *RBTree[K, V]) deleteNode(tgt *rbNode[K, V]) { + node := tgt // node左右非空,取后继节点 if node.left != nil && node.right != nil { s := rb.findSuccessor(node) diff --git a/mapx/builtin_map.go b/mapx/builtin_map.go new file mode 100644 index 00000000..a5621d1e --- /dev/null +++ b/mapx/builtin_map.go @@ -0,0 +1,52 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +// builtinMap 是对 map 的二次封装 +// 主要用于各种装饰器模式中被装饰的那个 +type builtinMap[K comparable, V any] struct { + data map[K]V +} + +func (b *builtinMap[K, V]) Put(key K, val V) error { + b.data[key] = val + return nil +} + +func (b *builtinMap[K, V]) Get(key K) (V, bool) { + val, ok := b.data[key] + return val, ok +} + +func (b *builtinMap[K, V]) Delete(k K) (V, bool) { + v, ok := b.data[k] + delete(b.data, k) + return v, ok +} + +// Keys 返回的 key 是随机的。即便对于同一个实例,调用两次,得到的结果都可能不同。 +func (b *builtinMap[K, V]) Keys() []K { + return Keys[K, V](b.data) +} + +func (b *builtinMap[K, V]) Values() []V { + return Values[K, V](b.data) +} + +func newBuiltinMap[K comparable, V any](capacity int) *builtinMap[K, V] { + return &builtinMap[K, V]{ + data: make(map[K]V, capacity), + } +} diff --git a/mapx/builtin_map_test.go b/mapx/builtin_map_test.go new file mode 100644 index 00000000..63572e66 --- /dev/null +++ b/mapx/builtin_map_test.go @@ -0,0 +1,215 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuiltinMap_Delete(t *testing.T) { + testCases := []struct { + name string + data map[string]string + + key string + + wantVal string + wantDeleted bool + }{ + { + name: "deleted", + data: map[string]string{ + "key1": "val1", + }, + key: "key1", + + wantVal: "val1", + wantDeleted: true, + }, + { + name: "key not exist", + data: map[string]string{ + "key1": "val1", + }, + key: "key2", + }, + { + name: "nil map", + key: "key2", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m := builtinMapOf[string, string](tc.data) + val, ok := m.Delete(tc.key) + assert.Equal(t, tc.wantDeleted, ok) + assert.Equal(t, tc.wantVal, val) + _, ok = m.data[tc.key] + assert.False(t, ok) + }) + } +} + +func TestBuiltinMap_Get(t *testing.T) { + testCases := []struct { + name string + data map[string]string + + key string + + wantVal string + found bool + }{ + { + name: "found", + data: map[string]string{ + "key1": "val1", + }, + key: "key1", + + wantVal: "val1", + found: true, + }, + { + name: "key not exist", + data: map[string]string{ + "key1": "val1", + }, + key: "key2", + }, + { + name: "nil map", + key: "key2", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m := builtinMapOf[string, string](tc.data) + val, ok := m.Get(tc.key) + assert.Equal(t, tc.found, ok) + assert.Equal(t, tc.wantVal, val) + }) + } +} + +func TestBuiltinMap_Put(t *testing.T) { + testCases := []struct { + name string + + key string + val string + cap int + + wantErr error + }{ + { + name: "puted", + key: "key1", + val: "val1", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m := newBuiltinMap[string, string](tc.cap) + err := m.Put(tc.key, tc.val) + assert.Equal(t, tc.wantErr, err) + v, ok := m.data[tc.key] + assert.True(t, ok) + assert.Equal(t, tc.val, v) + }) + } +} + +func TestBuiltinMap_Keys(t *testing.T) { + testCases := []struct { + name string + data map[string]string + + wantKeys []string + }{ + { + name: "got keys", + data: map[string]string{ + "key1": "val1", + "key2": "val1", + "key3": "val1", + "key4": "val1", + }, + wantKeys: []string{"key1", "key2", "key3", "key4"}, + }, + { + name: "empty map", + data: map[string]string{}, + wantKeys: []string{}, + }, + { + name: "nil map", + wantKeys: []string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m := builtinMapOf[string, string](tc.data) + keys := m.Keys() + assert.ElementsMatch(t, tc.wantKeys, keys) + }) + } +} + +func TestBuiltinMap_Values(t *testing.T) { + testCases := []struct { + name string + data map[string]string + + wantValues []string + }{ + { + name: "got values", + data: map[string]string{ + "key1": "val1", + "key2": "val2", + "key3": "val3", + "key4": "val4", + }, + wantValues: []string{"val1", "val2", "val3", "val4"}, + }, + { + name: "empty map", + data: map[string]string{}, + wantValues: []string{}, + }, + { + name: "nil map", + wantValues: []string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m := builtinMapOf[string, string](tc.data) + vals := m.Values() + assert.ElementsMatch(t, tc.wantValues, vals) + }) + } +} + +func builtinMapOf[K comparable, V any](data map[K]V) *builtinMap[K, V] { + return &builtinMap[K, V]{data: data} +} diff --git a/mapx/hashmap.go b/mapx/hashmap.go index eec9fda5..f2a7a6a5 100644 --- a/mapx/hashmap.go +++ b/mapx/hashmap.go @@ -17,12 +17,12 @@ package mapx import "github.com/ecodeclub/ekit/syncx" type node[T Hashable, ValType any] struct { - key Hashable + key T value ValType next *node[T, ValType] } -func (m *HashMap[T, ValType]) newNode(key Hashable, val ValType) *node[T, ValType] { +func (m *HashMap[T, ValType]) newNode(key T, val ValType) *node[T, ValType] { newNode := m.nodePool.Get() newNode.value = val newNode.key = key @@ -83,8 +83,8 @@ func (m *HashMap[T, ValType]) Get(key T) (ValType, bool) { // Keys 返回 Hashmap 里面的所有的 key。 // 注意:key 的顺序是随机的。 -func (m *HashMap[T, ValType]) Keys() []Hashable { - res := make([]Hashable, 0) +func (m *HashMap[T, ValType]) Keys() []T { + res := make([]T, 0) for _, bucketNode := range m.hashmap { curNode := bucketNode for curNode != nil { @@ -118,11 +118,6 @@ func NewHashMap[T Hashable, ValType any](size int) *HashMap[T, ValType] { } } -type mapi[T any, ValType any] interface { - Put(key T, val ValType) error - Get(key T) (ValType, bool) -} - var _ mapi[Hashable, any] = (*HashMap[Hashable, any])(nil) // Delete 第一个返回值为删除key的值,第二个是hashmap是否真的有这个key diff --git a/mapx/hashmap_test.go b/mapx/hashmap_test.go index edd23953..51161644 100644 --- a/mapx/hashmap_test.go +++ b/mapx/hashmap_test.go @@ -22,6 +22,9 @@ import ( "github.com/stretchr/testify/assert" ) +// 借助 testData 来验证一下 HashMap 实现了 mapi 接口 +var _ mapi[testData, int] = &HashMap[testData, int]{} + func TestHashMap(t *testing.T) { testKV := []struct { key testData @@ -541,5 +544,4 @@ func BenchmarkMyHashMap(b *testing.B) { _ = m[uint64(i)] } }) - } diff --git a/mapx/linkedmap.go b/mapx/linkedmap.go new file mode 100644 index 00000000..127f32af --- /dev/null +++ b/mapx/linkedmap.go @@ -0,0 +1,112 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import "github.com/ecodeclub/ekit" + +type LinkedMap[K any, V any] struct { + m mapi[K, *linkedKV[K, V]] + head, tail *linkedKV[K, V] + length int +} + +type linkedKV[K any, V any] struct { + key K + value V + prev, next *linkedKV[K, V] +} + +func NewLinkedHashMap[K Hashable, V any](size int) *LinkedMap[K, V] { + hashmap := NewHashMap[K, *linkedKV[K, V]](size) + head := &linkedKV[K, V]{} + tail := &linkedKV[K, V]{next: head, prev: head} + head.prev, head.next = tail, tail + return &LinkedMap[K, V]{ + m: hashmap, + head: head, + tail: tail, + } +} + +func NewLinkedTreeMap[K any, V any](comparator ekit.Comparator[K]) (*LinkedMap[K, V], error) { + treeMap, err := NewTreeMap[K, *linkedKV[K, V]](comparator) + if err != nil { + return nil, err + } + head := &linkedKV[K, V]{} + tail := &linkedKV[K, V]{next: head, prev: head} + head.prev, head.next = tail, tail + return &LinkedMap[K, V]{ + m: treeMap, + head: head, + tail: tail, + }, nil +} + +func (l *LinkedMap[K, V]) Put(key K, val V) error { + if lk, ok := l.m.Get(key); ok { + lk.value = val + return nil + } + lk := &linkedKV[K, V]{ + key: key, + value: val, + prev: l.tail.prev, + next: l.tail, + } + if err := l.m.Put(key, lk); err != nil { + return err + } + lk.prev.next, lk.next.prev = lk, lk + l.length++ + return nil +} + +func (l *LinkedMap[K, V]) Get(key K) (V, bool) { + if lk, ok := l.m.Get(key); ok { + return lk.value, ok + } + var v V + return v, false +} + +func (l *LinkedMap[K, V]) Delete(key K) (V, bool) { + if lk, ok := l.m.Delete(key); ok { + lk.prev.next = lk.next + lk.next.prev = lk.prev + l.length-- + return lk.value, ok + } + var v V + return v, false +} + +func (l *LinkedMap[K, V]) Keys() []K { + keys := make([]K, 0, l.length) + for cur := l.head.next; cur != l.tail; { + keys = append(keys, cur.key) + cur = cur.next + } + return keys +} + +func (l *LinkedMap[K, V]) Values() []V { + values := make([]V, 0, l.length) + for cur := l.head.next; cur != l.tail; { + values = append(values, cur.value) + cur = cur.next + } + return values +} diff --git a/mapx/linkedmap_test.go b/mapx/linkedmap_test.go new file mode 100644 index 00000000..a0060f3a --- /dev/null +++ b/mapx/linkedmap_test.go @@ -0,0 +1,469 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import ( + "errors" + "testing" + + "github.com/ecodeclub/ekit" + "github.com/stretchr/testify/assert" +) + +var ( + fakeErr = errors.New("fakeMap: put error") +) + +type fakeMap[K any, V any] struct { + *LinkedMap[K, V] + count int + activeFirstErr bool +} + +func (f *fakeMap[K, V]) Put(key K, val V) error { + f.count++ + if f.activeFirstErr { + f.activeFirstErr = false + return fakeErr + } + if f.count == 3 { + return fakeErr + } + if f.count == 5 { + return fakeErr + } + return f.LinkedMap.Put(key, val) +} + +func newLinkedFakeMap[K any, V any](activeFirstErr bool, comparator ekit.Comparator[K]) (*LinkedMap[K, V], error) { + treeMap, err := NewLinkedTreeMap[K, *linkedKV[K, V]](comparator) + if err != nil { + return nil, err + } + fm := &fakeMap[K, *linkedKV[K, V]]{LinkedMap: treeMap, activeFirstErr: activeFirstErr} + head := &linkedKV[K, V]{} + tail := &linkedKV[K, V]{next: head, prev: head} + head.prev, head.next = tail, tail + return &LinkedMap[K, V]{ + m: fm, + head: head, + tail: tail, + }, nil +} + +func TestLinkedMap_NewLinkedHashMap(t *testing.T) { + testCases := []struct { + name string + size int + }{ + { + name: "negative size", + size: -1, + }, + { + name: "zero size", + size: 0, + }, + { + name: "Positive size", + size: 1, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + multiMap := NewLinkedHashMap[testData, int](tt.size) + assert.NotNil(t, multiMap) + assert.Equal(t, multiMap.Keys(), []testData{}) + assert.Equal(t, multiMap.Values(), []int{}) + }) + } +} + +func TestLinkedMap_NewLinkedTreeMap(t *testing.T) { + testCases := []struct { + name string + comparator ekit.Comparator[int] + + wantErr error + }{ + { + name: "no error", + comparator: ekit.ComparatorRealNumber[int], + + wantErr: nil, + }, + { + name: "match errLinkedTreeMapComparatorIsNull error", + comparator: nil, + + wantErr: errTreeMapComparatorIsNull, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + linkedTreeMap, err := NewLinkedTreeMap[int, int](tt.comparator) + assert.Equal(t, tt.wantErr, err) + if err != nil { + assert.Nil(t, linkedTreeMap) + } else { + assert.NotNil(t, linkedTreeMap) + assert.Equal(t, linkedTreeMap.Keys(), []int{}) + assert.Equal(t, linkedTreeMap.Values(), []int{}) + } + }) + } +} + +func TestLinkedMap_Put(t *testing.T) { + testCases := []struct { + name string + linkedMap func(t *testing.T) *LinkedMap[int, int] + keys []int + values []int + + wantKeys []int + wantValues []int + wantErrs []error + }{ + { + name: "put single key", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + return linkedTreeMap + }, + keys: []int{1}, + values: []int{1}, + + wantKeys: []int{1}, + wantValues: []int{1}, + wantErrs: []error{nil}, + }, + { + name: "put multiple keys", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + return linkedTreeMap + }, + keys: []int{1, 2, 3, 4}, + values: []int{1, 2, 3, 4}, + + wantKeys: []int{1, 2, 3, 4}, + wantValues: []int{1, 2, 3, 4}, + wantErrs: []error{nil, nil, nil, nil}, + }, + { + name: "change value of single key", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + return linkedTreeMap + }, + keys: []int{1, 1, 2, 3}, + values: []int{1, 11, 2, 3}, + + wantKeys: []int{1, 2, 3}, + wantValues: []int{11, 2, 3}, + wantErrs: []error{nil, nil, nil, nil}, + }, + { + name: "change value of multiple keys", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + return linkedTreeMap + }, + keys: []int{1, 1, 2, 2, 3}, + values: []int{1, 11, 2, 22, 3}, + + wantKeys: []int{1, 2, 3}, + wantValues: []int{11, 22, 3}, + wantErrs: []error{nil, nil, nil, nil, nil}, + }, + { + name: "get error when put single key", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedFakeMap, err := newLinkedFakeMap[int, int](true, ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + return linkedFakeMap + }, + keys: []int{1}, + values: []int{1}, + + wantKeys: []int{}, + wantValues: []int{}, + wantErrs: []error{fakeErr}, + }, + { + name: "get multiple errors when put multiple keys", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedFakeMap, err := newLinkedFakeMap[int, int](true, ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + return linkedFakeMap + }, + keys: []int{1, 2, 3, 4, 5}, + values: []int{1, 2, 3, 4, 5}, + + wantKeys: []int{2, 4}, + wantValues: []int{2, 4}, + wantErrs: []error{fakeErr, nil, fakeErr, nil, fakeErr}, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + errs := make([]error, 0) + linkedMap := tt.linkedMap(t) + for i := range tt.keys { + err := linkedMap.Put(tt.keys[i], tt.values[i]) + errs = append(errs, err) + } + + for i := range tt.wantKeys { + v, b := linkedMap.Get(tt.wantKeys[i]) + assert.Equal(t, true, b) + assert.Equal(t, tt.wantValues[i], v) + } + + assert.Equal(t, tt.wantKeys, linkedMap.Keys()) + assert.Equal(t, tt.wantValues, linkedMap.Values()) + assert.Equal(t, tt.wantErrs, errs) + }) + } +} + +func TestLinkedMap_Get(t *testing.T) { + testCases := []struct { + name string + linkedMap func(t *testing.T) *LinkedMap[int, int] + key int + + wantValue int + wantBool bool + }{ + { + name: "can not find value in empty linked map", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + return linkedTreeMap + }, + key: 1, + + wantValue: 0, + wantBool: false, + }, + { + name: "can not find value in linked map", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + err = linkedTreeMap.Put(1, 1) + assert.NoError(t, err) + return linkedTreeMap + }, + key: 2, + + wantValue: 0, + wantBool: false, + }, + { + name: "find value in linked map", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + err = linkedTreeMap.Put(1, 1) + assert.NoError(t, err) + return linkedTreeMap + }, + key: 1, + + wantValue: 1, + wantBool: true, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + v, b := tt.linkedMap(t).Get(tt.key) + assert.Equal(t, tt.wantBool, b) + assert.Equal(t, tt.wantValue, v) + }) + } +} + +func TestLinkedMap_Delete(t *testing.T) { + testCases := []struct { + name string + linkedMap func(t *testing.T) *LinkedMap[int, int] + + key int + + delValue int + wantBool bool + wantKeys []int + wantValues []int + }{ + { + name: "delete key in empty linked map", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + return linkedTreeMap + }, + + key: 1, + + delValue: 0, + wantBool: false, + wantKeys: []int{}, + wantValues: []int{}, + }, + { + name: "delete unknown key in not empty linked map", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + assert.NoError(t, linkedTreeMap.Put(1, 1)) + return linkedTreeMap + }, + + key: 2, + + delValue: 0, + wantBool: false, + wantKeys: []int{1}, + wantValues: []int{1}, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + linkedMap := tt.linkedMap(t) + v, b := linkedMap.Delete(tt.key) + assert.Equal(t, tt.wantBool, b) + assert.Equal(t, tt.delValue, v) + + assert.Equal(t, tt.wantKeys, linkedMap.Keys()) + assert.Equal(t, tt.wantValues, linkedMap.Values()) + }) + } +} + +func TestLinkedMap_PutAndDelete(t *testing.T) { + testCases := []struct { + name string + linkedMap func(t *testing.T) *LinkedMap[int, int] + + wantKeys []int + wantValues []int + }{ + { + name: "put k1 → delete k1", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + assert.NoError(t, linkedTreeMap.Put(1, 1)) + v, ok := linkedTreeMap.Delete(1) + assert.Equal(t, 1, v) + assert.Equal(t, true, ok) + return linkedTreeMap + }, + + wantKeys: []int{}, + wantValues: []int{}, + }, + { + name: "put k1 → put k2 → delete k1", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + assert.NoError(t, linkedTreeMap.Put(1, 1)) + assert.NoError(t, linkedTreeMap.Put(2, 2)) + v, ok := linkedTreeMap.Delete(1) + assert.Equal(t, 1, v) + assert.Equal(t, true, ok) + return linkedTreeMap + }, + + wantKeys: []int{2}, + wantValues: []int{2}, + }, + { + name: "put k1 → put k2 → delete k2", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + assert.NoError(t, linkedTreeMap.Put(1, 1)) + assert.NoError(t, linkedTreeMap.Put(2, 2)) + v, ok := linkedTreeMap.Delete(2) + assert.Equal(t, 2, v) + assert.Equal(t, true, ok) + return linkedTreeMap + }, + + wantKeys: []int{1}, + wantValues: []int{1}, + }, + { + name: "put k1 → delete k1 → put k2 → put k3", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + assert.NoError(t, linkedTreeMap.Put(1, 1)) + v, ok := linkedTreeMap.Delete(1) + assert.Equal(t, 1, v) + assert.Equal(t, true, ok) + assert.NoError(t, linkedTreeMap.Put(2, 2)) + assert.NoError(t, linkedTreeMap.Put(3, 3)) + + return linkedTreeMap + }, + + wantKeys: []int{2, 3}, + wantValues: []int{2, 3}, + }, + { + name: "put k1 → put k2 → put k3 → delete k2", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, err := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, err) + assert.NoError(t, linkedTreeMap.Put(1, 1)) + assert.NoError(t, linkedTreeMap.Put(2, 2)) + assert.NoError(t, linkedTreeMap.Put(3, 3)) + v, ok := linkedTreeMap.Delete(2) + assert.Equal(t, 2, v) + assert.Equal(t, true, ok) + + return linkedTreeMap + }, + + wantKeys: []int{1, 3}, + wantValues: []int{1, 3}, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + linkedMap := tt.linkedMap(t) + for i := range tt.wantKeys { + v, b := linkedMap.Get(tt.wantKeys[i]) + assert.Equal(t, true, b) + assert.Equal(t, tt.wantValues[i], v) + } + assert.Equal(t, tt.wantKeys, linkedMap.Keys()) + assert.Equal(t, tt.wantValues, linkedMap.Values()) + }) + } +} diff --git a/mapx/multi_map.go b/mapx/multi_map.go new file mode 100644 index 00000000..f6497164 --- /dev/null +++ b/mapx/multi_map.go @@ -0,0 +1,95 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import ( + "github.com/ecodeclub/ekit" +) + +// MultiMap 多映射的 Map +// 它可以将一个键映射到多个值上 +type MultiMap[K any, V any] struct { + m mapi[K, []V] +} + +// NewMultiTreeMap 创建一个基于 TreeMap 的 MultiMap +// 注意: +// - comparator 不能为 nil +func NewMultiTreeMap[K any, V any](comparator ekit.Comparator[K]) (*MultiMap[K, V], error) { + treeMap, err := NewTreeMap[K, []V](comparator) + if err != nil { + return nil, err + } + return &MultiMap[K, V]{ + m: treeMap, + }, nil +} + +// NewMultiHashMap 创建一个基于 HashMap 的 MultiMap +func NewMultiHashMap[K Hashable, V any](size int) *MultiMap[K, V] { + var m mapi[K, []V] = NewHashMap[K, []V](size) + return &MultiMap[K, V]{ + m: m, + } +} + +func NewMultiBuiltinMap[K comparable, V any](size int) *MultiMap[K, V] { + var m mapi[K, []V] = newBuiltinMap[K, []V](size) + return &MultiMap[K, V]{ + m: m, + } +} + +// Put 在 MultiMap 中添加键值对或向已有键 k 的值追加数据 +func (m *MultiMap[K, V]) Put(k K, v V) error { + return m.PutMany(k, v) +} + +// PutMany 在 MultiMap 中添加键值对或向已有键 k 的值追加多个数据 +func (m *MultiMap[K, V]) PutMany(k K, v ...V) error { + val, _ := m.Get(k) + val = append(val, v...) + return m.m.Put(k, val) +} + +// Get 从 MultiMap 中获取已有键 k 的值 +// 如果键 k 不存在,则返回的 bool 值为 false +// 返回的切片是一个副本,你对该切片的修改不会影响原本的数据。 +func (m *MultiMap[K, V]) Get(k K) ([]V, bool) { + if v, ok := m.m.Get(k); ok { + return append([]V{}, v...), ok + } + return nil, false +} + +// Delete 从 MultiMap 中删除指定的键 k +func (m *MultiMap[K, V]) Delete(k K) ([]V, bool) { + return m.m.Delete(k) +} + +// Keys 返回 MultiMap 所有的键 +func (m *MultiMap[K, V]) Keys() []K { + return m.m.Keys() +} + +// Values 返回 MultiMap 所有的值 +func (m *MultiMap[K, V]) Values() [][]V { + values := m.m.Values() + copyValues := make([][]V, 0, len(values)) + for i := range values { + copyValues = append(copyValues, append([]V{}, values[i]...)) + } + return copyValues +} diff --git a/mapx/multi_map_test.go b/mapx/multi_map_test.go new file mode 100644 index 00000000..0875763a --- /dev/null +++ b/mapx/multi_map_test.go @@ -0,0 +1,596 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import ( + "testing" + + "github.com/ecodeclub/ekit" + "github.com/stretchr/testify/assert" +) + +func getMultiTreeMap() *MultiMap[int, int] { + multiTreeMap, _ := NewMultiTreeMap[int, int](ekit.ComparatorRealNumber[int]) + return multiTreeMap +} +func getMultiHashMap() *MultiMap[testData, int] { + return NewMultiHashMap[testData, int](10) +} + +func TestMultiMap_NewMultiHashMap(t *testing.T) { + testCases := []struct { + name string + size int + }{ + { + name: "negative size", + size: -1, + }, + { + name: "zero size", + size: 0, + }, + { + name: "Positive size", + size: 1, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + multiMap := NewMultiHashMap[testData, int](tt.size) + assert.NotNil(t, multiMap) + }) + } +} + +func TestMultiMap_NewMultiTreeMap(t *testing.T) { + testCases := []struct { + name string + comparator ekit.Comparator[int] + + wantErr error + }{ + { + name: "no error", + comparator: ekit.ComparatorRealNumber[int], + + wantErr: nil, + }, + { + name: "match errMultiMapComparatorIsNull error", + comparator: nil, + + wantErr: errTreeMapComparatorIsNull, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + multiMap, err := NewMultiTreeMap[int, int](tt.comparator) + assert.Equal(t, tt.wantErr, err) + if err != nil { + assert.Nil(t, multiMap) + } else { + assert.NotNil(t, multiMap) + } + }) + } +} + +func TestNewMultiBuiltinMap(t *testing.T) { + testCases := []struct { + name string + size int + }{ + { + name: "negative size", + size: -1, + }, + { + name: "zero size", + size: 0, + }, + { + name: "Positive size", + size: 1, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + multiMap := NewMultiBuiltinMap[testData, int](tt.size) + assert.NotNil(t, multiMap) + }) + } +} + +func TestMultiMap_Keys(t *testing.T) { + testCases := []struct { + name string + multiTreeMap *MultiMap[int, int] + multiHashMap *MultiMap[testData, int] + + wantMultiTreeMapKeys []int + wantMultiHashMapKeys []testData + }{ + { + name: "empty", + multiTreeMap: func() *MultiMap[int, int] { + return getMultiTreeMap() + }(), + multiHashMap: func() *MultiMap[testData, int] { + return getMultiHashMap() + }(), + + wantMultiTreeMapKeys: []int{}, + wantMultiHashMapKeys: []testData{}, + }, + { + name: "single one", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + return multiHashMap + }(), + + wantMultiTreeMapKeys: []int{1}, + wantMultiHashMapKeys: []testData{{id: 1}}, + }, + { + name: "multiple", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + _ = multiTreeMap.Put(2, 2) + _ = multiTreeMap.Put(3, 3) + _ = multiTreeMap.Put(4, 4) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + _ = multiHashMap.Put(testData{id: 2}, 2) + _ = multiHashMap.Put(testData{id: 3}, 3) + _ = multiHashMap.Put(testData{id: 4}, 4) + return multiHashMap + }(), + + wantMultiTreeMapKeys: []int{1, 2, 3, 4}, + wantMultiHashMapKeys: []testData{ + {id: 1}, + {id: 2}, + {id: 3}, + {id: 4}, + }, + }, + } + for _, tt := range testCases { + t.Run("MultiTreeMap", func(t *testing.T) { + assert.ElementsMatch(t, tt.wantMultiTreeMapKeys, tt.multiTreeMap.Keys()) + }) + + t.Run("MultiHashMap", func(t *testing.T) { + assert.ElementsMatch(t, tt.wantMultiHashMapKeys, tt.multiHashMap.Keys()) + }) + } +} + +func TestMultiMap_Values(t *testing.T) { + testCases := []struct { + name string + multiTreeMap *MultiMap[int, int] + multiHashMap *MultiMap[testData, int] + + wantValues [][]int + }{ + { + name: "empty", + multiTreeMap: func() *MultiMap[int, int] { + return getMultiTreeMap() + }(), + multiHashMap: func() *MultiMap[testData, int] { + return getMultiHashMap() + }(), + + wantValues: [][]int{}, + }, + { + name: "single one", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + return multiHashMap + }(), + + wantValues: [][]int{{1}}, + }, + { + name: "multiple", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + _ = multiTreeMap.Put(2, 2) + _ = multiTreeMap.Put(3, 3) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + _ = multiHashMap.Put(testData{id: 2}, 2) + _ = multiHashMap.Put(testData{id: 3}, 3) + return multiHashMap + }(), + + wantValues: [][]int{{1}, {2}, {3}}, + }, + } + t.Run("MultiTreeMap", func(t *testing.T) { + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + assert.ElementsMatch(t, tt.wantValues, tt.multiTreeMap.Values()) + }) + } + }) + t.Run("MultiHashMap", func(t *testing.T) { + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + assert.ElementsMatch(t, tt.wantValues, tt.multiHashMap.Values()) + }) + } + }) + +} + +func TestMultiMap_Put(t *testing.T) { + testCases := []struct { + name string + keys []int + values []int + + wantKeys []int + wantValues [][]int + wantErr error + }{ + { + name: "put simple one", + keys: []int{1}, + values: []int{1}, + + wantKeys: []int{1}, + wantValues: [][]int{{1}}, + wantErr: nil, + }, + { + name: "put multiple", + keys: []int{1, 2, 3, 4}, + values: []int{1, 2, 3, 4}, + + wantKeys: []int{1, 2, 3, 4}, + wantValues: [][]int{{1}, {2}, {3}, {4}}, + wantErr: nil, + }, + { + name: "the key include the same", + keys: []int{1, 2, 1, 4}, + values: []int{1, 2, 3, 4}, + + wantKeys: []int{1, 2, 4}, + wantValues: [][]int{ + {1, 3}, + {2}, + {4}, + }, + wantErr: nil, + }, + } + for _, tt := range testCases { + t.Run("MultiTreeMap", func(t *testing.T) { + multiTreeMap, _ := NewMultiTreeMap[int, int](ekit.ComparatorRealNumber[int]) + for i := range tt.keys { + err := multiTreeMap.Put(tt.keys[i], tt.values[i]) + assert.Equal(t, tt.wantErr, err) + } + + for i := range tt.wantKeys { + v, b := multiTreeMap.Get(tt.wantKeys[i]) + assert.Equal(t, true, b) + assert.Equal(t, tt.wantValues[i], v) + } + }) + + t.Run("MultiHashMap", func(t *testing.T) { + multiHashMap := NewMultiHashMap[testData, int](10) + for i := range tt.keys { + err := multiHashMap.Put(testData{id: tt.keys[i]}, tt.values[i]) + assert.Equal(t, tt.wantErr, err) + } + + for i := range tt.wantKeys { + v, b := multiHashMap.Get(testData{id: tt.wantKeys[i]}) + assert.Equal(t, true, b) + assert.Equal(t, tt.wantValues[i], v) + } + }) + } +} + +func TestMultiMap_Get(t *testing.T) { + testCases := []struct { + name string + multiTreeMap *MultiMap[int, int] + multiHashMap *MultiMap[testData, int] + key int + + wantValue []int + wantBool bool + }{ + { + name: "not found (nil) in empty data", + multiTreeMap: func() *MultiMap[int, int] { + return getMultiTreeMap() + }(), + multiHashMap: func() *MultiMap[testData, int] { + return getMultiHashMap() + }(), + key: 1, + + wantValue: nil, + wantBool: false, + }, + { + name: "not found (nil) in data", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + _ = multiTreeMap.Put(2, 2) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + _ = multiHashMap.Put(testData{id: 2}, 2) + return multiHashMap + }(), + key: 3, + + wantValue: nil, + wantBool: false, + }, + { + name: "found data", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + return multiHashMap + }(), + key: 1, + + wantValue: []int{1}, + wantBool: true, + }, + } + for _, tt := range testCases { + t.Run("MultiTreeMap", func(t *testing.T) { + v, b := tt.multiTreeMap.Get(tt.key) + assert.Equal(t, tt.wantBool, b) + assert.ElementsMatch(t, tt.wantValue, v) + }) + + t.Run("MultiHashMap", func(t *testing.T) { + v2, b2 := tt.multiHashMap.Get(testData{id: tt.key}) + assert.Equal(t, tt.wantBool, b2) + assert.ElementsMatch(t, tt.wantValue, v2) + }) + } +} + +func TestMultiMap_Delete(t *testing.T) { + testCases := []struct { + name string + multiTreeMap *MultiMap[int, int] + multiHashMap *MultiMap[testData, int] + + key int + + delValue []int + wantBool bool + }{ + { + name: "not found in empty data", + multiTreeMap: func() *MultiMap[int, int] { + return getMultiTreeMap() + }(), + multiHashMap: func() *MultiMap[testData, int] { + return getMultiHashMap() + }(), + + key: 1, + + delValue: nil, + wantBool: false, + }, + { + name: "not found in data", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + return multiHashMap + }(), + + key: 2, + + delValue: nil, + wantBool: false, + }, + { + name: "found and deleted", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + _ = multiTreeMap.Put(2, 2) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + _ = multiHashMap.Put(testData{id: 2}, 2) + return multiHashMap + }(), + key: 1, + + delValue: []int{1}, + wantBool: true, + }, + } + for _, tt := range testCases { + t.Run("MultiTreeMap", func(t *testing.T) { + v, b := tt.multiTreeMap.Delete(tt.key) + assert.Equal(t, tt.wantBool, b) + assert.ElementsMatch(t, tt.delValue, v) + }) + t.Run("MultiHashMap", func(t *testing.T) { + v, b := tt.multiHashMap.Delete(testData{id: tt.key}) + assert.Equal(t, tt.wantBool, b) + assert.ElementsMatch(t, tt.delValue, v) + }) + } +} + +func TestMultiMap_PutMany(t *testing.T) { + testCases := []struct { + name string + keys []int + values [][]int + + wantKeys []int + wantValues [][]int + wantErr error + }{ + { + name: "one to one", + keys: []int{1}, + values: [][]int{{1}}, + + wantKeys: []int{1}, + wantValues: [][]int{{1}}, + wantErr: nil, + }, + { + name: "many [one to one]", + keys: []int{1, 2, 3}, + values: [][]int{{1}, {2}, {3}}, + + wantKeys: []int{1, 2, 3}, + wantValues: [][]int{{1}, {2}, {3}}, + wantErr: nil, + }, + { + name: "one to many", + keys: []int{1}, + values: [][]int{{1, 2, 3}}, + + wantKeys: []int{1}, + wantValues: [][]int{ + {1, 2, 3}, + }, + wantErr: nil, + }, + { + name: "many [one to many]", + keys: []int{1, 2, 3}, + values: [][]int{{1, 2, 3}, {1, 2, 3}, {1, 2, 3}}, + + wantKeys: []int{1, 2, 3}, + wantValues: [][]int{ + {1, 2, 3}, + {1, 2, 3}, + {1, 2, 3}, + }, + wantErr: nil, + }, + { + name: "the key include the same for append one", + keys: []int{1, 1}, + values: [][]int{{1, 2, 3, 4, 5}, {6}}, + + wantKeys: []int{1}, + wantValues: [][]int{ + {1, 2, 3, 4, 5, 6}, + }, + wantErr: nil, + }, + { + name: "the key include the same for append many", + keys: []int{1, 1}, + values: [][]int{{1}, {2, 3, 4, 5, 6}}, + + wantKeys: []int{1}, + wantValues: [][]int{ + {1, 2, 3, 4, 5, 6}, + }, + wantErr: nil, + }, + } + for _, tt := range testCases { + t.Run("MultiTreeMap", func(t *testing.T) { + multiTreeMap, _ := NewMultiTreeMap[int, int](ekit.ComparatorRealNumber[int]) + for i := range tt.keys { + err := multiTreeMap.PutMany(tt.keys[i], tt.values[i]...) + assert.Equal(t, tt.wantErr, err) + } + + for i := range tt.wantKeys { + v, b := multiTreeMap.Get(tt.wantKeys[i]) + assert.Equal(t, true, b) + assert.Equal(t, tt.wantValues[i], v) + } + }) + + t.Run("MultiHashMap", func(t *testing.T) { + multiHashMap := NewMultiHashMap[testData, int](10) + for i := range tt.keys { + err := multiHashMap.PutMany(testData{id: tt.keys[i]}, tt.values[i]...) + assert.Equal(t, tt.wantErr, err) + } + + for i := range tt.wantKeys { + v, b := multiHashMap.Get(testData{id: tt.wantKeys[i]}) + assert.Equal(t, true, b) + assert.Equal(t, tt.wantValues[i], v) + } + }) + } +} diff --git a/mapx/treemap.go b/mapx/treemap.go index 24b0aa79..719d696d 100644 --- a/mapx/treemap.go +++ b/mapx/treemap.go @@ -27,7 +27,7 @@ var ( // TreeMap 是基于红黑树实现的Map type TreeMap[K any, V any] struct { - *tree.RBTree[K, V] + tree *tree.RBTree[K, V] } // NewTreeMapWithMap TreeMap构造方法 @@ -48,7 +48,7 @@ func NewTreeMap[K any, V any](compare ekit.Comparator[K]) (*TreeMap[K, V], error return nil, errTreeMapComparatorIsNull } return &TreeMap[K, V]{ - RBTree: tree.NewRBTree[K, V](compare), + tree: tree.NewRBTree[K, V](compare), }, nil } @@ -63,9 +63,9 @@ func putAll[K comparable, V any](treeMap *TreeMap[K, V], m map[K]V) { // Put 在TreeMap插入指定值 // 需注意如果TreeMap已存在该Key那么原值会被替换 func (treeMap *TreeMap[K, V]) Put(key K, value V) error { - err := treeMap.Add(key, value) + err := treeMap.tree.Add(key, value) if err == tree.ErrRBTreeSameRBNode { - return treeMap.Set(key, value) + return treeMap.tree.Set(key, value) } return nil } @@ -73,13 +73,27 @@ func (treeMap *TreeMap[K, V]) Put(key K, value V) error { // Get 在TreeMap找到指定Key的节点,返回Val // TreeMap未找到指定节点将会返回false func (treeMap *TreeMap[K, V]) Get(key K) (V, bool) { - v, err := treeMap.Find(key) + v, err := treeMap.tree.Find(key) return v, err == nil } -// Remove TreeMap中删除指定key的节点 -func (treeMap *TreeMap[T, V]) Remove(k T) { - treeMap.Delete(k) +// Delete TreeMap中删除指定key的节点 +func (treeMap *TreeMap[T, V]) Delete(k T) (V, bool) { + return treeMap.tree.Delete(k) +} + +// Keys 返回了全部的键 +// 目前我们是按照中序遍历来返回的数据,但是你不能依赖于这个特性 +func (treeMap *TreeMap[T, V]) Keys() []T { + keys, _ := treeMap.tree.KeyValues() + return keys +} + +// Values 返回了全部的值 +// 目前我们是按照中序遍历来返回的数据,但是你不能依赖于这个特性 +func (treeMap *TreeMap[T, V]) Values() []V { + _, vals := treeMap.tree.KeyValues() + return vals } var _ mapi[any, any] = (*TreeMap[any, any])(nil) diff --git a/mapx/treemap_test.go b/mapx/treemap_test.go index f6e9fc86..378000fd 100644 --- a/mapx/treemap_test.go +++ b/mapx/treemap_test.go @@ -18,6 +18,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/require" + "github.com/ecodeclub/ekit" "github.com/stretchr/testify/assert" ) @@ -277,19 +279,92 @@ func TestTreeMap_Put(t *testing.T) { } } -func TestTreeMap_Remove(t *testing.T) { - var tests = []struct { +func TestTreeMap_Keys(t *testing.T) { + testCases := []struct { name string - m map[int]int - delKey int - wantVal int - wantBool bool + data map[int]int + wantKeys []int + }{ + { + name: "no data", + wantKeys: []int{}, + }, + { + name: "data", + data: map[int]int{ + 1: 11, + 2: 12, + 0: 10, + 3: 13, + 5: 15, + 4: 14, + }, + wantKeys: []int{0, 1, 2, 3, 4, 5}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tm, err := NewTreeMap[int, int](compare()) + require.NoError(t, err) + for k, v := range tc.data { + err = tm.Put(k, v) + require.NoError(t, err) + } + keys := tm.Keys() + assert.Equal(t, tc.wantKeys, keys) + }) + } +} + +func TestTreeMap_Values(t *testing.T) { + testCases := []struct { + name string + data map[int]int + wantValues []int }{ { - name: "empty-TreeMap", - m: map[int]int{}, - delKey: 0, - wantVal: 0, + name: "no data", + wantValues: []int{}, + }, + { + name: "data", + data: map[int]int{ + 1: 11, + 2: 12, + 0: 10, + 3: 13, + 5: 15, + 4: 14, + }, + wantValues: []int{10, 11, 12, 13, 14, 15}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tm, err := NewTreeMap[int, int](compare()) + require.NoError(t, err) + for k, v := range tc.data { + err = tm.Put(k, v) + require.NoError(t, err) + } + vals := tm.Values() + assert.Equal(t, tc.wantValues, vals) + }) + } +} + +func TestTreeMap_Delete(t *testing.T) { + var tests = []struct { + name string + m map[int]int + delKey int + delVal int + deleted bool + }{ + { + name: "empty-TreeMap", + m: map[int]int{}, + delKey: 0, }, { name: "find", @@ -302,7 +377,8 @@ func TestTreeMap_Remove(t *testing.T) { 4: 4, }, delKey: 2, - wantVal: 0, + deleted: true, + delVal: 2, }, { name: "not-find", @@ -314,17 +390,21 @@ func TestTreeMap_Remove(t *testing.T) { 5: 5, 4: 4, }, - delKey: 6, - wantVal: 0, + delKey: 6, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { treeMap, _ := NewTreeMap[int, int](compare()) - treeMap.Remove(tt.delKey) - val, err := treeMap.Get(tt.delKey) - assert.Equal(t, tt.wantBool, err) - assert.Equal(t, tt.wantVal, val) + for k, v := range tt.m { + err := treeMap.Put(k, v) + require.NoError(t, err) + } + delVal, ok := treeMap.Delete(tt.delKey) + assert.Equal(t, tt.deleted, ok) + assert.Equal(t, tt.delVal, delVal) + _, ok = treeMap.Get(tt.delKey) + assert.False(t, ok) }) } } diff --git a/mapx/types.go b/mapx/types.go new file mode 100644 index 00000000..8b43c7aa --- /dev/null +++ b/mapx/types.go @@ -0,0 +1,32 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +type mapi[K any, V any] interface { + Put(key K, val V) error + Get(key K) (V, bool) + // Delete 删除 + // 第一个返回值是被删除的 key 对应的值 + // 第二个返回值是代表是否真的删除了 + Delete(k K) (V, bool) + // Keys 返回所有的键 + // 注意,当你调用多次拿到的结果不一定相等 + // 取决于具体实现 + Keys() []K + // Values 返回所有的值 + // 注意,当你调用多次拿到的结果不一定相等 + // 取决于具体实现 + Values() []V +} diff --git a/pool/task_pool.go b/pool/task_pool.go index 8b0a65bb..78b34dfa 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -136,14 +136,9 @@ type OnDemandBlockTaskPool struct { // 协程id方便调试程序 id int32 - // 外部信号 - //shutdownDone chan struct{} - shutdownCtx context.Context - shutdownCancel context.CancelFunc - - // 内部中断信号 - shutdownNowCtx context.Context - shutdownNowCancel context.CancelFunc + // 中断信号 + interruptCtx context.Context + interruptCtxCancel context.CancelFunc } // NewOnDemandBlockTaskPool 创建一个新的 OnDemandBlockTaskPool @@ -165,8 +160,7 @@ func NewOnDemandBlockTaskPool(initGo int, queueSize int, opts ...option.Option[O maxIdleTime: defaultMaxIdleTime, } ctx := context.Background() - b.shutdownCtx, b.shutdownCancel = context.WithCancel(ctx) - b.shutdownNowCtx, b.shutdownNowCancel = context.WithCancel(ctx) + b.interruptCtx, b.interruptCtxCancel = context.WithCancel(ctx) atomic.StoreInt32(&b.state, stateCreated) option.Apply(b, opts...) @@ -275,25 +269,8 @@ func (b *OnDemandBlockTaskPool) trySubmit(ctx context.Context, task Task, state func (b *OnDemandBlockTaskPool) allowToCreateGoroutine() bool { b.mutex.RLock() defer b.mutex.RUnlock() - - if b.totalGo == b.maxGo { - return false - } - - // 这个判断可能太苛刻了,经常导致开协程失败,先注释掉 - // allGoShouldBeBusy := atomic.LoadInt32(&b.numGoRunningTasks) == b.totalGo - // if !allGoShouldBeBusy { - // return false - // } - rate := float64(len(b.queue)) / float64(cap(b.queue)) - if rate == 0 || rate < b.queueBacklogRate { - // log.Println("rate == 0", rate == 0, "rate", rate, " < ", b.queueBacklogRate) - return false - } - - // b.totalGo < b.maxGo && rate != 0 && rate >= b.queueBacklogRate - return true + return (b.totalGo < b.maxGo) && (rate != 0 && rate >= b.queueBacklogRate) } // Start 开始调度任务执行 @@ -316,17 +293,7 @@ func (b *OnDemandBlockTaskPool) Start() error { if atomic.CompareAndSwapInt32(&b.state, stateCreated, stateLocked) { - n := b.initGo - - allowGo := b.maxGo - b.initGo - needGo := int32(len(b.queue)) - b.initGo - if needGo > 0 { - if needGo <= allowGo { - n += needGo - } else { - n += allowGo - } - } + n := b.numOfGoThatCanBeCreate() b.increaseTotalGo(n) for i := int32(0); i < n; i++ { @@ -338,6 +305,20 @@ func (b *OnDemandBlockTaskPool) Start() error { } } +func (b *OnDemandBlockTaskPool) numOfGoThatCanBeCreate() int32 { + n := b.initGo + allowGo := b.maxGo - b.initGo + needGo := int32(len(b.queue)) - b.initGo + if needGo > 0 { + if needGo <= allowGo { + n += needGo + } else { + n += allowGo + } + } + return n +} + func (b *OnDemandBlockTaskPool) goroutine(id int) { // 刚启动的协程除非恰巧赶上Shutdown/ShutdownNow被调用,否则应该至少执行一个task @@ -349,7 +330,7 @@ func (b *OnDemandBlockTaskPool) goroutine(id int) { for { // log.Println("id", id, "working for loop") select { - case <-b.shutdownNowCtx.Done(): + case <-b.interruptCtx.Done(): // log.Printf("id %d shutdownNow, timeoutGroup.Size=%d left\n", id, b.timeoutGroup.size()) b.decreaseTotalGo(1) return @@ -372,50 +353,42 @@ func (b *OnDemandBlockTaskPool) goroutine(id int) { } // log.Println("id", id, "out timeoutGroup") } - atomic.AddInt32(&b.numGoRunningTasks, 1) if !ok { - // b.numGoRunningTasks > 1表示虽然当前协程监听到了b.queue关闭但还有其他协程运行task,当前协程自己退出就好 - // b.numGoRunningTasks == 1表示只有当前协程"运行task"中,其他协程在一定在"拿到b.queue到已关闭",这一信号的路上 - // 绝不会处于运行task中 - if atomic.LoadInt32(&b.state) == stateClosing && atomic.CompareAndSwapInt32(&b.numGoRunningTasks, 1, 0) { - // 在b.queue关闭后,第一个检测到全部task已经自然结束的协程 - // 状态迁移 + b.decreaseTotalGo(1) + if b.numOfGo() == 0 { + // 因调用Shutdown方法导致的协程退出,最后一个退出的协程负责状态迁移及显示通知外部调用者 if atomic.CompareAndSwapInt32(&b.state, stateClosing, stateStopped) { - // 显示通知外部调用者 - b.shutdownCancel() + b.interruptCtxCancel() } - - b.decreaseTotalGo(1) - return } - - // 有其他协程运行task中,自己退出就好。 - atomic.AddInt32(&b.numGoRunningTasks, -1) - b.decreaseTotalGo(1) return } + // todo handle error - _ = task.Run(b.shutdownNowCtx) + atomic.AddInt32(&b.numGoRunningTasks, 1) + _ = task.Run(b.interruptCtx) atomic.AddInt32(&b.numGoRunningTasks, -1) b.mutex.Lock() // log.Println("id", id, "totalGo-mem", b.totalGo-b.timeoutGroup.size(), "totalGo", b.totalGo, "mem", b.timeoutGroup.size()) - if b.coreGo < b.totalGo && (len(b.queue) == 0 || int32(len(b.queue)) < b.totalGo) { - // 协程在(coreGo,maxGo]区间 - // 如果没有任务可以执行,或者被判定为可能抢不到任务的协程直接退出 - // 注意:一定要在此处减1才能让此刻等待在mutex上的其他协程被正确地分区 + noTasksToExecute := len(b.queue) == 0 || int32(len(b.queue)) < b.totalGo + if b.coreGo < b.totalGo && b.totalGo <= b.maxGo && noTasksToExecute { + // 当前协程属于(coreGo,maxGo]区间,发现没有任务可以执行故直接退出 + // 注意:一定要在此处减1才能让此刻等待在mutex上的其他协程被正确地划分区间 b.totalGo-- // log.Println("id", id, "exits....") b.mutex.Unlock() return } - if b.initGo < b.totalGo-b.timeoutGroup.size() /* && len(b.queue) == 0 */ { + if b.initGo < b.totalGo-b.timeoutGroup.size() { // log.Println("id", id, "initGo", b.initGo, "totalGo-mem", b.totalGo-b.timeoutGroup.size(), "totalGo", b.totalGo) - // 协程在(initGo,coreGo]区间,如果没有任务可以执行,重置计时器 - // 当len(b.queue) != 0时,即便协程属于(coreGo,maxGo]区间,也应该给它一个定时器兜底。 - // 因为现在看队列中有任务,等真去拿的时候可能恰好没任务,如果不给它一个定时器兜底此时就会出现当前协程总数长时间大于始协程数(initGo)的情况。 - // 直到队列再次有任务时才可能将当前总协程数准确无误地降至初始协程数,因此注释掉len(b.queue) == 0判断条件 + // 根据需求: + // 1. 如果当前协程属于(initGo,coreGo]区间,需要为其分配一个超时器。 + // - 当前协程在超时退出前(最大空闲时间内)尝试拿任务,拿到则继续执行,没拿到则超时退出。 + // 2. 如果当前协程属于(coreGo, maxGo]区间,且有任务可执行,也需要为其分配一个超时器兜底。 + // - 因为此时看队列中有任务,等真去拿的时候可能恰好没任务 + // - 这会导致当前协程总数(totalGo)长时间大于始协程数(initGo)直到队列再次有任务时才可能将当前总协程数准确地降至初始协程数 idleTimer = time.NewTimer(b.maxIdleTime) b.timeoutGroup.add(id) // log.Println("id", id, "add timeoutGroup", "size", b.timeoutGroup.size()) @@ -465,7 +438,7 @@ func (b *OnDemandBlockTaskPool) Shutdown() (<-chan struct{}, error) { // 先关闭等待队列不再允许提交 // 同时工作协程能够通过判断b.queue是否被关闭来终止获取任务循环 close(b.queue) - return b.shutdownCtx.Done(), nil + return b.interruptCtx.Done(), nil } } @@ -495,7 +468,7 @@ func (b *OnDemandBlockTaskPool) ShutdownNow() ([]Task, error) { close(b.queue) // 发送中断信号,中断工作协程获取任务循环 - b.shutdownNowCancel() + b.interruptCtxCancel() // 清空队列并保存 tasks := make([]Task, 0, len(b.queue)) @@ -530,11 +503,8 @@ func (b *OnDemandBlockTaskPool) States(ctx context.Context, interval time.Durati if ctx.Err() != nil { return nil, ctx.Err() } - if b.shutdownNowCtx.Err() != nil { - return nil, b.shutdownNowCtx.Err() - } - if b.shutdownCtx.Err() != nil { - return nil, b.shutdownCtx.Err() + if b.interruptCtx.Err() != nil { + return nil, b.interruptCtx.Err() } statsChan := make(chan State) @@ -549,11 +519,7 @@ func (b *OnDemandBlockTaskPool) States(ctx context.Context, interval time.Durati b.sendState(statsChan, time.Now().UnixNano()) close(statsChan) return - case <-b.shutdownNowCtx.Done(): - b.sendState(statsChan, time.Now().UnixNano()) - close(statsChan) - return - case <-b.shutdownCtx.Done(): + case <-b.interruptCtx.Done(): b.sendState(statsChan, time.Now().UnixNano()) close(statsChan) return diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index 95379a67..bc9cbabd 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -29,255 +29,154 @@ import ( func TestOnDemandBlockTaskPool_States(t *testing.T) { t.Parallel() - t.Run("ctx canceled", func(t *testing.T) { - p1, err := NewOnDemandBlockTaskPool(2, 5) - assert.NoError(t, err) - testTaskPoolStatesCtxCanceled(t, p1, context.Canceled) - }) - t.Run("shutdownNowCtx canceled", func(t *testing.T) { - p1, err := NewOnDemandBlockTaskPool(2, 5) - assert.NoError(t, err) - testTaskPoolStatesShutdownNowCtxCanceled(t, p1, context.Canceled) - }) + t.Run("调用States方法时使用已取消的context应该返回错误", func(t *testing.T) { + t.Parallel() - t.Run("shutdownCtx canceled", func(t *testing.T) { - p1, err := NewOnDemandBlockTaskPool(2, 5) + pool, err := NewOnDemandBlockTaskPool(1, 3) assert.NoError(t, err) - testTaskPoolStatesShutdownCtxCanceled(t, p1, context.Canceled) - }) - t.Run("ctx Running canceled", func(t *testing.T) { - p2, err := NewOnDemandBlockTaskPool(2, 5) - assert.NoError(t, err) - testTaskPoolStatesCtxRunningCanceled(t, p2, - State{PoolState: stateRunning, GoCnt: 2, - WaitingTasksCnt: 3, QueueSize: 5, RunningTasksCnt: 2}) - }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() - t.Run("pool not running", func(t *testing.T) { - p, err := NewOnDemandBlockTaskPool(2, 5) - assert.NoError(t, err) - testTaskPoolStatesPoolNotRunning(t, p, - State{PoolState: stateCreated, GoCnt: 0, WaitingTasksCnt: 5, QueueSize: 5, RunningTasksCnt: 0}) + _, err = pool.States(ctx, time.Millisecond) + assert.Equal(t, context.Canceled, err) }) - t.Run("pool Shutdown", func(t *testing.T) { - p, err := NewOnDemandBlockTaskPool(2, 5) + t.Run("调用ShutdownNow方法后再调用States方法应该返回错误", func(t *testing.T) { + t.Parallel() + + pool, err := NewOnDemandBlockTaskPool(1, 3) assert.NoError(t, err) - testTaskPoolStatesPoolShutdown(t, p, - State{PoolState: stateClosing, GoCnt: 2, WaitingTasksCnt: 3, QueueSize: 5, RunningTasksCnt: 2}, - State{PoolState: stateStopped, GoCnt: 0, WaitingTasksCnt: 0, QueueSize: 5, RunningTasksCnt: 0}) - }) - t.Run("pool Shutdown Now", func(t *testing.T) { - p, err := NewOnDemandBlockTaskPool(1, 2) + err = pool.Start() assert.NoError(t, err) - testTaskPoolStatesPoolShutdownNow(t, p) - }) -} -func testTaskPoolStatesCtxCanceled(t *testing.T, pool *OnDemandBlockTaskPool, wantErr error) { - done := make(chan struct{}) - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-done - return nil - })) - assert.NoError(t, err) + _, err = pool.ShutdownNow() + assert.NoError(t, err) - err = pool.Start() - assert.NoError(t, err) + _, err = pool.States(context.Background(), time.Millisecond) + assert.Equal(t, context.Canceled, err) + }) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - cancel() - _, err = pool.States(ctx, time.Millisecond) - assert.Equal(t, wantErr, err) - close(done) -} + t.Run("调用Shutdown方法后再调用States方法应该返回错误", func(t *testing.T) { + t.Parallel() -func testTaskPoolStatesShutdownNowCtxCanceled(t *testing.T, pool *OnDemandBlockTaskPool, wantErr error) { - done := make(chan struct{}) - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-done - return nil - })) - assert.NoError(t, err) + pool, err := NewOnDemandBlockTaskPool(1, 3) + assert.NoError(t, err) - err = pool.Start() - assert.NoError(t, err) - done <- struct{}{} - _, err = pool.ShutdownNow() - assert.NoError(t, err) + err = pool.Start() + assert.NoError(t, err) - _, err = pool.States(context.Background(), time.Millisecond) - assert.Equal(t, wantErr, err) - close(done) -} + done, err := pool.Shutdown() + assert.NoError(t, err) -func testTaskPoolStatesShutdownCtxCanceled(t *testing.T, pool *OnDemandBlockTaskPool, wantErr error) { - done := make(chan struct{}) - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { <-done - return nil - })) - assert.NoError(t, err) - - err = pool.Start() - assert.NoError(t, err) - // 当 queue 里的任务为 0 个时, 调用 Shutdown() 并不会执行相应的 cancel - //done <- struct{}{} - _, err = pool.Shutdown() - assert.NoError(t, err) - done <- struct{}{} - - _, err = pool.States(context.Background(), time.Millisecond) - assert.Equal(t, wantErr, err) - close(done) -} + _, err = pool.States(context.Background(), time.Millisecond) + assert.Equal(t, context.Canceled, err) + }) -func testTaskPoolStatesCtxRunningCanceled(t *testing.T, pool *OnDemandBlockTaskPool, wantState State) { - err := pool.Start() - assert.NoError(t, err) + t.Run("调用States方法返回的chan应该能够正常读取数据", func(t *testing.T) { + t.Parallel() - done := make(chan struct{}) - n := cap(pool.queue) + pool, err := NewOnDemandBlockTaskPool(1, 3) + assert.NoError(t, err) - for i := 0; i < n; i++ { - err = pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-done - return nil - })) + ch, err := pool.States(context.Background(), time.Millisecond) assert.NoError(t, err) - } + assert.NotZero(t, <-ch) + }) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - ch, err := pool.States(ctx, time.Millisecond) - assert.NoError(t, err) - state1 := <-ch - assert.Equal(t, wantState.PoolState, state1.PoolState) - assert.Equal(t, wantState.QueueSize, state1.QueueSize) - assert.Equal(t, wantState.GoCnt, state1.GoCnt) - assert.Equal(t, wantState.WaitingTasksCnt, state1.WaitingTasksCnt) - assert.Equal(t, wantState.RunningTasksCnt, state1.RunningTasksCnt) - - cancel() - for { - state2, ok := <-ch - if !ok { - break - } - assert.Equal(t, wantState.PoolState, state2.PoolState) - assert.Equal(t, wantState.QueueSize, state2.QueueSize) - assert.Equal(t, wantState.GoCnt, state2.GoCnt) - assert.Equal(t, wantState.WaitingTasksCnt, state2.WaitingTasksCnt) - assert.Equal(t, wantState.RunningTasksCnt, state2.RunningTasksCnt) - } - close(done) -} + t.Run("当调用States方法时传入的context超时返回的chan应该被关闭", func(t *testing.T) { + t.Parallel() -func testTaskPoolStatesPoolNotRunning(t *testing.T, pool *OnDemandBlockTaskPool, wantState State) { - done := make(chan struct{}) - n := cap(pool.queue) + initGo, queueSize := 1, 3 + pool, syncChan := testNewRunningStateTaskPoolWithQueueFullFilled(t, initGo, queueSize) - for i := 0; i < n; i++ { - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-done - return nil - })) + ctx, cancel := context.WithCancel(context.Background()) + ch, err := pool.States(ctx, time.Millisecond) assert.NoError(t, err) - } - ch, err := pool.States(context.Background(), time.Millisecond) - assert.NoError(t, err) - state1 := <-ch - assert.Equal(t, wantState.PoolState, state1.PoolState) - assert.Equal(t, wantState.QueueSize, state1.QueueSize) - assert.Equal(t, wantState.GoCnt, state1.GoCnt) - assert.Equal(t, wantState.WaitingTasksCnt, state1.WaitingTasksCnt) - assert.Equal(t, wantState.RunningTasksCnt, state1.RunningTasksCnt) - close(done) -} + go func() { + // simulate timeout + <-time.After(3 * time.Millisecond) + cancel() + }() -func testTaskPoolStatesPoolShutdown(t *testing.T, pool *OnDemandBlockTaskPool, closingState, stoppedState State) { - done := make(chan struct{}) - n := cap(pool.queue) + for { + state, ok := <-ch + if !ok { + break + } + assert.NotZero(t, state) + } - for i := 0; i < n; i++ { - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-done - return nil - })) + // clean up + close(syncChan) + _, err = pool.Shutdown() assert.NoError(t, err) - } + }) - err := pool.Start() - assert.NoError(t, err) + t.Run("调用Shutdown方法应该关闭States方法返回的chan", func(t *testing.T) { + t.Parallel() - _, err = pool.Shutdown() - assert.NoError(t, err) + pool := testNewRunningStateTaskPool(t, 1, 3) - ch, err := pool.States(context.Background(), time.Millisecond) - assert.NoError(t, err) - state1 := <-ch - assert.Equal(t, closingState.PoolState, state1.PoolState) - assert.Equal(t, closingState.QueueSize, state1.QueueSize) - assert.Equal(t, closingState.GoCnt, state1.GoCnt) - assert.Equal(t, closingState.WaitingTasksCnt, state1.WaitingTasksCnt) - assert.Equal(t, closingState.RunningTasksCnt, state1.RunningTasksCnt) + ch, err := pool.States(context.Background(), time.Millisecond) + assert.NoError(t, err) - close(done) - for { - state2, ok := <-ch - if !ok { - break + go func() { + time.Sleep(5 * time.Millisecond) + _, err := pool.Shutdown() + assert.NoError(t, err) + }() + + for { + state, ok := <-ch + if !ok { + break + } + assert.NotZero(t, state) } - assert.Equal(t, stoppedState.PoolState, state2.PoolState) - assert.Equal(t, stoppedState.QueueSize, state2.QueueSize) - assert.Equal(t, stoppedState.GoCnt, state2.GoCnt) - assert.Equal(t, stoppedState.WaitingTasksCnt, state2.WaitingTasksCnt) - assert.Equal(t, stoppedState.RunningTasksCnt, state2.RunningTasksCnt) - } -} + }) -func testTaskPoolStatesPoolShutdownNow(t *testing.T, pool *OnDemandBlockTaskPool) { - done := make(chan struct{}) - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-done - return nil - })) - assert.NoError(t, err) + t.Run("调用ShutdownNow方法应该关闭States方法返回的chan", func(t *testing.T) { + t.Parallel() - err = pool.Start() - assert.NoError(t, err) + pool := testNewRunningStateTaskPool(t, 1, 3) - ch, err := pool.States(context.Background(), time.Millisecond) - assert.NoError(t, err) - done <- struct{}{} - _, err = pool.ShutdownNow() - assert.NoError(t, err) + ch, err := pool.States(context.Background(), time.Millisecond) + assert.NoError(t, err) - for { - state, ok := <-ch - if !ok { - break - } - assert.Equal(t, stateStopped, state.PoolState) - } + go func() { + time.Sleep(5 * time.Millisecond) + _, err := pool.ShutdownNow() + assert.NoError(t, err) + }() - close(done) + for { + state, ok := <-ch + if !ok { + break + } + assert.NotZero(t, state) + } + }) } /* TaskPool有限状态机 - Start/Submit/ShutdownNow() Error - \ / - Shutdown() --> CLOSING ---等待所有任务结束 - Submit()nil--执行中状态迁移--Submit() / \----------/ \----------/ - \ / \ / / -New() --> CREATED -- Start() ---> RUNNING -- -- - \ / \ / \ Start/Submit/Shutdown() Error - Shutdown/ShutdownNow()Error Start() \ \ / - ShutdownNow() ---> STOPPED -- ShutdownNow() --> STOPPED + Start/Submit/Shutdown/ShutdownNow() Error + \ / + Shutdown() --> CLOSING --> 等待所有任务结束 + States/Submit()---执行中状态迁移--States/Submit() / \ / | + \ / \ / / States() | +New() ---> CREATED ----- Start() ------> RUNNING ------ | + \ / \ / \ | + Shutdown/ShutdownNow()Error Start() \ | + ShutdownNow() ---> STOPPED <-------- | + \ / + Start/Submit/Shutdown/ShutdownNow/States() Error */ func TestOnDemandBlockTaskPool_In_Created_State(t *testing.T) { @@ -477,8 +376,10 @@ func TestOnDemandBlockTaskPool_In_Running_State(t *testing.T) { }) t.Run("Start —— 在TaskPool启动前队列中已有任务,启动后不再Submit", func(t *testing.T) { + t.Parallel() t.Run("WithCoreGo,WithMaxIdleTime,所需要协程数 <= 允许创建的协程数", func(t *testing.T) { + t.Parallel() initGo, coreGo, maxIdleTime := 1, 3, 3*time.Millisecond queueSize := coreGo @@ -511,9 +412,12 @@ func TestOnDemandBlockTaskPool_In_Running_State(t *testing.T) { <-wait } assert.Equal(t, int32(coreGo), pool.numOfGo()) + close(done) }) t.Run("WithMaxGo, 所需要协程数 > 允许创建的协程数", func(t *testing.T) { + t.Parallel() + initGo, maxGo := 3, 5 queueSize := maxGo + 1 @@ -545,10 +449,12 @@ func TestOnDemandBlockTaskPool_In_Running_State(t *testing.T) { <-wait } assert.Equal(t, int32(maxGo), pool.numOfGo()) + close(done) }) }) t.Run("Start —— 与Submit并发调用,WithCoreGo,WithMaxIdleTime,WithMaxGo,所需要协程数 < 允许创建的协程数", func(t *testing.T) { + t.Parallel() initGo, coreGo, maxGo, maxIdleTime := 2, 4, 6, 3*time.Millisecond queueSize := coreGo @@ -588,6 +494,7 @@ func TestOnDemandBlockTaskPool_In_Running_State(t *testing.T) { } assert.Equal(t, int32(maxGo), pool.numOfGo()) + close(done) }) t.Run("Submit", func(t *testing.T) { @@ -896,21 +803,21 @@ func TestOnDemandBlockTaskPool_In_Closing_State(t *testing.T) { pool := testNewRunningStateTaskPool(t, initGo, queueSize) // 模拟阻塞提交 - n := initGo + queueSize*2 + n := initGo + queueSize + 1 eg := new(errgroup.Group) - waitChan := make(chan struct{}, n) + waitChan := make(chan struct{}) taskDone := make(chan struct{}) for i := 0; i < n; i++ { eg.Go(func() error { return pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - waitChan <- struct{}{} + <-waitChan <-taskDone return nil })) }) } for i := 0; i < initGo; i++ { - <-waitChan + waitChan <- struct{}{} } done, err := pool.Shutdown() assert.NoError(t, err) @@ -925,6 +832,7 @@ func TestOnDemandBlockTaskPool_In_Closing_State(t *testing.T) { assert.Equal(t, int32(initGo), pool.numOfGo()) + close(waitChan) close(taskDone) <-done assert.Equal(t, stateStopped, pool.internalState()) @@ -1204,6 +1112,8 @@ func testNewRunningStateTaskPoolWithQueueFullFilled(t *testing.T, initGo int, qu } func TestGroup(t *testing.T) { + t.Parallel() + n := 10 // g := &sliceGroup{members: make([]int, n, n)} diff --git a/randx/rand_code.go b/randx/rand_code.go new file mode 100644 index 00000000..776d0559 --- /dev/null +++ b/randx/rand_code.go @@ -0,0 +1,88 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package randx + +import ( + "errors" + "math/rand" +) + +var ERRTYPENOTSUPPORTTED = errors.New("ekit:不支持的类型") + +type TYPE int + +const ( + TYPE_DEFAULT TYPE = 0 //默认类型 + TYPE_DIGIT TYPE = 1 //数字// + TYPE_LETTER TYPE = 2 //小写字母 + TYPE_CAPITAL TYPE = 3 //大写字母 + TYPE_MIXED TYPE = 4 //数字+字母混合 +) + +// RandCode 根据传入的长度和类型生成随机字符串,这个方法目前可以生成数字、字母、数字+字母的随机字符串 +func RandCode(length int, typ TYPE) (string, error) { + switch typ { + case TYPE_DEFAULT: + fallthrough + case TYPE_DIGIT: + return generate("0123456789", length, 4), nil + case TYPE_LETTER: + return generate("abcdefghijklmnopqrstuvwxyz", length, 5), nil + case TYPE_CAPITAL: + return generate("ABCDEFGHIJKLMNOPQRSTUVWXYZ", length, 5), nil + case TYPE_MIXED: + return generate("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", length, 7), nil + default: + return "", ERRTYPENOTSUPPORTTED + } +} + +// generate 根据传入的随机源和长度生成随机字符串,一次随机,多次使用 +func generate(source string, length, idxBits int) string { + + //掩码 + //例如: 使用低6位:0000 0000 --> 0011 1111 + idxMask := 1<>= idxBits + + //扣减remain + remain-- + + } + return string(result) + +} diff --git a/randx/rand_code_test.go b/randx/rand_code_test.go new file mode 100644 index 00000000..ee962eee --- /dev/null +++ b/randx/rand_code_test.go @@ -0,0 +1,92 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package randx + +import ( + "errors" + "regexp" + "testing" +) + +func TestRandCode(t *testing.T) { + testCases := []struct { + name string + length int + typ TYPE + wantMatch string + wantErr error + }{ + { + name: "默认类型", + length: 8, + typ: TYPE_DEFAULT, + wantMatch: "^[0-9]+$", + wantErr: nil, + }, + { + name: "数字验证码", + length: 8, + typ: TYPE_DIGIT, + wantMatch: "^[0-9]+$", + wantErr: nil, + }, { + name: "小写字母验证码", + length: 8, + typ: TYPE_LETTER, + wantMatch: "^[a-z]+$", + wantErr: nil, + }, { + name: "大写字母验证码", + length: 8, + typ: TYPE_CAPITAL, + wantMatch: "^[A-Z]+$", + wantErr: nil, + }, { + name: "混合验证码", + length: 8, + typ: TYPE_MIXED, + wantMatch: "^[0-9a-zA-Z]+$", + wantErr: nil, + }, { + name: "未定义类型", + length: 8, + typ: 9, + wantMatch: "", + wantErr: ERRTYPENOTSUPPORTTED, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + code, err := RandCode(tc.length, tc.typ) + if err != nil { + if !errors.Is(err, tc.wantErr) { + t.Errorf("unexpected error: %v", err) + } + } else { + //长度检验 + if len(code) != tc.length { + t.Errorf("expected length: %d but got length:%d ", tc.length, len(code)) + } + //模式检验 + matched, _ := regexp.MatchString(tc.wantMatch, code) + if !matched { + t.Errorf("expected %s but got %s", tc.wantMatch, code) + } + } + }) + } + +} diff --git a/set/set.go b/set/set.go index dee45070..b1d194b3 100644 --- a/set/set.go +++ b/set/set.go @@ -16,8 +16,8 @@ package set type Set[T comparable] interface { Add(key T) - // 返回是否存在这个元素 Delete(key T) + // Exist 返回是否存在这个元素 Exist(key T) bool Keys() []T } diff --git a/set/set_test.go b/set/set_test.go index 9cf80ae2..7abd87a7 100644 --- a/set/set_test.go +++ b/set/set_test.go @@ -35,7 +35,7 @@ func TestSetx_Add(t *testing.T) { }) } -func TestSetx_Remove(t *testing.T) { +func TestSetx_Delete(t *testing.T) { testcases := []struct { name string delVal int diff --git a/set/treeset.go b/set/treeset.go index 788f84c0..1661c48c 100644 --- a/set/treeset.go +++ b/set/treeset.go @@ -48,8 +48,7 @@ func (s *TreeSet[T]) Exist(key T) bool { // Keys 方法返回的元素顺序不固定 func (s *TreeSet[T]) Keys() []T { - keys, _ := s.treeMap.KeyValues() - return keys + return s.treeMap.Keys() } var _ Set[int] = (*TreeSet[int])(nil) diff --git a/slice/add.go b/slice/add.go new file mode 100644 index 00000000..a553a9f4 --- /dev/null +++ b/slice/add.go @@ -0,0 +1,24 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import "github.com/ecodeclub/ekit/internal/slice" + +// Add 在index处添加元素 +// index 范围应为[0, len(src)) +func Add[Src any](src []Src, element Src, index int) ([]Src, error) { + res, err := slice.Add[Src](src, element, index) + return res, err +} diff --git a/slice/add_test.go b/slice/add_test.go new file mode 100644 index 00000000..078d2b35 --- /dev/null +++ b/slice/add_test.go @@ -0,0 +1,71 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import ( + "fmt" + "testing" + + "github.com/ecodeclub/ekit/internal/errs" + + "github.com/stretchr/testify/assert" +) + +func TestAdd(t *testing.T) { + // Add 主要依赖于 internal/slice.Add 来保证正确性 + testCases := []struct { + name string + slice []int + addVal int + index int + wantSlice []int + wantErr error + }{ + { + name: "index 0", + slice: []int{123, 100}, + addVal: 233, + index: 0, + wantSlice: []int{233, 123, 100}, + }, + { + name: "index -1", + slice: []int{123, 100}, + index: -1, + wantErr: errs.NewErrIndexOutOfRange(2, -1), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res, err := Add(tc.slice, tc.addVal, tc.index) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantSlice, res) + }) + } +} + +func ExampleAdd() { + res, _ := Add[int]([]int{1, 2, 3, 4}, 233, 2) + fmt.Println(res) + _, err := Add[int]([]int{1, 2, 3, 4}, 233, -1) + fmt.Println(err) + // Output: + // [1 2 233 3 4] + // ekit: 下标超出范围,长度 4, 下标 -1 +} diff --git a/slice/contains.go b/slice/contains.go index 2d378a9f..e7431dc6 100644 --- a/slice/contains.go +++ b/slice/contains.go @@ -16,17 +16,17 @@ package slice // Contains 判断 src 里面是否存在 dst func Contains[T comparable](src []T, dst T) bool { - return ContainsFunc[T](src, dst, func(src, dst T) bool { + return ContainsFunc[T](src, func(src T) bool { return src == dst }) } // ContainsFunc 判断 src 里面是否存在 dst // 你应该优先使用 Contains -func ContainsFunc[T any](src []T, dst T, equal equalFunc[T]) bool { +func ContainsFunc[T any](src []T, equal func(src T) bool) bool { // 遍历调用equal函数进行判断 for _, v := range src { - if equal(v, dst) { + if equal(v) { return true } } @@ -72,7 +72,9 @@ func ContainsAll[T comparable](src, dst []T) bool { // 你应该优先使用 ContainsAll func ContainsAllFunc[T any](src, dst []T, equal equalFunc[T]) bool { for _, valDst := range dst { - if !ContainsFunc[T](src, valDst, equal) { + if !ContainsFunc[T](src, func(src T) bool { + return equal(src, valDst) + }) { return false } } diff --git a/slice/contains_test.go b/slice/contains_test.go index 015841eb..43f2e2f5 100644 --- a/slice/contains_test.go +++ b/slice/contains_test.go @@ -92,8 +92,8 @@ func TestContainsFunc(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - assert.Equal(t, test.want, ContainsFunc[int](test.src, test.dst, func(src, dst int) bool { - return src == dst + assert.Equal(t, test.want, ContainsFunc[int](test.src, func(src int) bool { + return src == test.dst })) }) } @@ -287,8 +287,8 @@ func ExampleContains() { } func ExampleContainsFunc() { - res := ContainsFunc[int]([]int{1, 2, 3}, 3, func(src, dst int) bool { - return src == dst + res := ContainsFunc[int]([]int{1, 2, 3}, func(src int) bool { + return src == 3 }) fmt.Println(res) // Output: diff --git a/slice/diff.go b/slice/diff.go index 7d9dea36..5d12d65f 100644 --- a/slice/diff.go +++ b/slice/diff.go @@ -34,10 +34,11 @@ func DiffSet[T comparable](src, dst []T) []T { // DiffSetFunc 差集,已去重 // 你应该优先使用 DiffSet func DiffSetFunc[T any](src, dst []T, equal equalFunc[T]) []T { - // TODO 优化容量预估 var ret = make([]T, 0, len(src)) for _, val := range src { - if !ContainsFunc[T](dst, val, equal) { + if !ContainsFunc[T](dst, func(src T) bool { + return equal(src, val) + }) { ret = append(ret, val) } } diff --git a/slice/find.go b/slice/find.go new file mode 100644 index 00000000..88677f68 --- /dev/null +++ b/slice/find.go @@ -0,0 +1,43 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +// Find 查找元素 +// 如果没有找到,第二个返回值返回 false +func Find[T any](src []T, match matchFunc[T]) (T, bool) { + for _, val := range src { + if match(val) { + return val, true + } + } + var t T + return t, false +} + +// FindAll 查找所有符合条件的元素 +// 永远不会返回 nil +func FindAll[T any](src []T, match matchFunc[T]) []T { + // 我们认为符合条件元素应该是少数 + // 所以会除以 8 + // 也就是触发扩容的情况下,最多三次就会和原本的容量一样 + // +1 是为了保证,至少有一个元素 + res := make([]T, 0, len(src)>>3+1) + for _, val := range src { + if match(val) { + res = append(res, val) + } + } + return res +} diff --git a/slice/find_test.go b/slice/find_test.go new file mode 100644 index 00000000..90f2d30e --- /dev/null +++ b/slice/find_test.go @@ -0,0 +1,149 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFind(t *testing.T) { + testCases := []struct { + name string + input []Number + match matchFunc[Number] + + wantVal Number + found bool + }{ + { + name: "找到了", + input: []Number{ + {val: 123}, + {val: 234}, + }, + match: func(src Number) bool { + return src.val == 123 + }, + wantVal: Number{val: 123}, + found: true, + }, + { + name: "没找到", + input: []Number{ + {val: 123}, + {val: 234}, + }, + match: func(src Number) bool { + return src.val == 456 + }, + }, + { + name: "nil", + match: func(src Number) bool { + return src.val == 123 + }, + }, + { + name: "没有元素", + input: []Number{}, + match: func(src Number) bool { + return src.val == 123 + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val, found := Find[Number](tc.input, tc.match) + assert.Equal(t, tc.found, found) + assert.Equal(t, tc.wantVal, val) + }) + } +} + +func TestFindAll(t *testing.T) { + testCases := []struct { + name string + input []Number + match matchFunc[Number] + + wantVals []Number + }{ + { + name: "没有符合条件的", + input: []Number{{val: 2}, {val: 4}}, + match: func(src Number) bool { + return src.val%2 == 1 + }, + wantVals: []Number{}, + }, + { + name: "找到了", + input: []Number{{val: 2}, {val: 3}, {val: 4}}, + match: func(src Number) bool { + return src.val%2 == 1 + }, + wantVals: []Number{{val: 3}}, + }, + { + name: "nil", + match: func(src Number) bool { + return src.val%2 == 1 + }, + wantVals: []Number{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + vals := FindAll[Number](tc.input, tc.match) + assert.Equal(t, tc.wantVals, vals) + }) + } +} + +func ExampleFind() { + val, ok := Find[int]([]int{1, 2, 3}, func(src int) bool { + return src == 2 + }) + fmt.Println(val, ok) + val, ok = Find[int]([]int{1, 2, 3}, func(src int) bool { + return src == 4 + }) + fmt.Println(val, ok) + // Output: + // 2 true + // 0 false +} + +func ExampleFindAll() { + vals := FindAll[int]([]int{2, 3, 4}, func(src int) bool { + return src%2 == 1 + }) + fmt.Println(vals) + vals = FindAll[int]([]int{2, 3, 4}, func(src int) bool { + return src > 5 + }) + fmt.Println(vals) + // Output: + // [3] + // [] +} + +type Number struct { + val int +} diff --git a/slice/index.go b/slice/index.go index 3da1034c..029bd49d 100644 --- a/slice/index.go +++ b/slice/index.go @@ -17,17 +17,17 @@ package slice // Index 返回和 dst 相等的第一个元素下标 // -1 表示没找到 func Index[T comparable](src []T, dst T) int { - return IndexFunc[T](src, dst, func(src, dst T) bool { + return IndexFunc[T](src, func(src T) bool { return src == dst }) } -// IndexFunc 返回和 dst 相等的第一个元素下标 +// IndexFunc 返回 match 返回 true 的第一个下标 // -1 表示没找到 // 你应该优先使用 Index -func IndexFunc[T any](src []T, dst T, equal equalFunc[T]) int { +func IndexFunc[T any](src []T, match matchFunc[T]) int { for k, v := range src { - if equal(v, dst) { + if match(v) { return k } } @@ -37,7 +37,7 @@ func IndexFunc[T any](src []T, dst T, equal equalFunc[T]) int { // LastIndex 返回和 dst 相等的最后一个元素下标 // -1 表示没找到 func LastIndex[T comparable](src []T, dst T) int { - return LastIndexFunc[T](src, dst, func(src, dst T) bool { + return LastIndexFunc[T](src, func(src T) bool { return src == dst }) } @@ -45,9 +45,9 @@ func LastIndex[T comparable](src []T, dst T) int { // LastIndexFunc 返回和 dst 相等的最后一个元素下标 // -1 表示没找到 // 你应该优先使用 LastIndex -func LastIndexFunc[T any](src []T, dst T, equal equalFunc[T]) int { +func LastIndexFunc[T any](src []T, match matchFunc[T]) int { for i := len(src) - 1; i >= 0; i-- { - if equal(dst, src[i]) { + if match(src[i]) { return i } } @@ -56,17 +56,17 @@ func LastIndexFunc[T any](src []T, dst T, equal equalFunc[T]) int { // IndexAll 返回和 dst 相等的所有元素的下标 func IndexAll[T comparable](src []T, dst T) []int { - return IndexAllFunc[T](src, dst, func(src, dst T) bool { + return IndexAllFunc[T](src, func(src T) bool { return src == dst }) } -// IndexAllFunc 返回和 dst 相等的所有元素的下标 +// IndexAllFunc 返回和 match 返回 true 的所有元素的下标 // 你应该优先使用 IndexAll -func IndexAllFunc[T any](src []T, dst T, equal equalFunc[T]) []int { +func IndexAllFunc[T any](src []T, match matchFunc[T]) []int { var indexes = make([]int, 0, len(src)) for k, v := range src { - if equal(v, dst) { + if match(v) { indexes = append(indexes, k) } } diff --git a/slice/index_test.go b/slice/index_test.go index 03c1199c..c29bdfaf 100644 --- a/slice/index_test.go +++ b/slice/index_test.go @@ -104,8 +104,8 @@ func TestIndexFunc(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - assert.Equal(t, test.want, IndexFunc[int](test.src, test.dst, func(src, dst int) bool { - return src == dst + assert.Equal(t, test.want, IndexFunc[int](test.src, func(src int) bool { + return src == test.dst })) }) } @@ -191,8 +191,8 @@ func TestLastIndexFunc(t *testing.T) { }, } for _, test := range tests { - assert.Equal(t, test.want, LastIndexFunc[int](test.src, test.dst, func(src, dst int) bool { - return src == dst + assert.Equal(t, test.want, LastIndexFunc[int](test.src, func(src int) bool { + return src == test.dst })) } } @@ -268,8 +268,8 @@ func TestIndexAllFunc(t *testing.T) { }, } for _, test := range tests { - res := IndexAllFunc[int](test.src, test.dst, func(src, dst int) bool { - return src == dst + res := IndexAllFunc[int](test.src, func(src int) bool { + return src == test.dst }) assert.ElementsMatch(t, test.want, res) } @@ -286,12 +286,12 @@ func ExampleIndex() { } func ExampleIndexFunc() { - res := IndexFunc[int]([]int{1, 2, 3}, 1, func(src, dst int) bool { - return src == dst + res := IndexFunc[int]([]int{1, 2, 3}, func(src int) bool { + return src == 1 }) fmt.Println(res) - res = IndexFunc[int]([]int{1, 2, 3}, 4, func(src, dst int) bool { - return src == dst + res = IndexFunc[int]([]int{1, 2, 3}, func(src int) bool { + return src == 4 }) fmt.Println(res) // Output: @@ -310,12 +310,12 @@ func ExampleIndexAll() { } func ExampleIndexAllFunc() { - res := IndexAllFunc[int]([]int{1, 2, 3, 4, 5, 3, 9}, 3, func(src, dst int) bool { - return src == dst + res := IndexAllFunc[int]([]int{1, 2, 3, 4, 5, 3, 9}, func(src int) bool { + return src == 3 }) fmt.Println(res) - res = IndexAllFunc[int]([]int{1, 2, 3}, 4, func(src, dst int) bool { - return src == dst + res = IndexAllFunc[int]([]int{1, 2, 3}, func(src int) bool { + return src == 4 }) fmt.Println(res) // Output: diff --git a/slice/intersect.go b/slice/intersect.go index f82edeff..6ba45050 100644 --- a/slice/intersect.go +++ b/slice/intersect.go @@ -33,12 +33,11 @@ func IntersectSet[T comparable](src []T, dst []T) []T { // 已去重 func IntersectSetFunc[T any](src []T, dst []T, equal equalFunc[T]) []T { var ret = make([]T, 0, len(src)) - for _, valSrc := range src { - for _, valDst := range dst { - if equal(valDst, valSrc) { - ret = append(ret, valSrc) - break - } + for _, v := range dst { + if ContainsFunc[T](src, func(t T) bool { + return equal(t, v) + }) { + ret = append(ret, v) } } return deduplicateFunc[T](ret, equal) diff --git a/slice/map.go b/slice/map.go index 7732de48..11ce496f 100644 --- a/slice/map.go +++ b/slice/map.go @@ -49,7 +49,9 @@ func toMap[T comparable](src []T) map[T]struct{} { func deduplicateFunc[T any](data []T, equal equalFunc[T]) []T { var newData = make([]T, 0, len(data)) for k, v := range data { - if !ContainsFunc[T](data[k+1:], v, equal) { + if !ContainsFunc[T](data[k+1:], func(src T) bool { + return equal(src, v) + }) { newData = append(newData, v) } } diff --git a/slice/symmetric_diff.go b/slice/symmetric_diff.go index 2b970cc4..f31dab87 100644 --- a/slice/symmetric_diff.go +++ b/slice/symmetric_diff.go @@ -19,56 +19,45 @@ package slice // 返回值的元素顺序是不定的 func SymmetricDiffSet[T comparable](src, dst []T) []T { srcMap, dstMap := toMap[T](src), toMap[T](dst) - for dstKey := range dstMap { - if _, exist := srcMap[dstKey]; exist { - // 删除共同元素,两者剩余的并集即为对称差 - delete(dstMap, dstKey) - delete(srcMap, dstKey) + for k := range dstMap { + if _, ok := srcMap[k]; ok { + delete(srcMap, k) + } else { + srcMap[k] = struct{}{} } } - for k, v := range dstMap { - srcMap[k] = v - } - var ret = make([]T, 0, len(srcMap)) + res := make([]T, 0, len(srcMap)) for k := range srcMap { - ret = append(ret, k) + res = append(res, k) } - return ret + return res } // SymmetricDiffSetFunc 对称差集 // 你应该优先使用 SymmetricDiffSet // 已去重 func SymmetricDiffSetFunc[T any](src, dst []T, equal equalFunc[T]) []T { - var interSection = make([]T, 0, min(len(src), len(dst))) - for _, valSrc := range src { - for _, valDst := range dst { - if equal(valSrc, valDst) { - interSection = append(interSection, valSrc) - break - } - } - } + res := []T{} - ret := make([]T, 0, len(src)+len(dst)-len(interSection)*2) + //找出在src不在dst的元素 for _, v := range src { - if !ContainsFunc[T](interSection, v, equal) { - ret = append(ret, v) + if !ContainsFunc[T](dst, func(t T) bool { + return equal(t, v) + }) { + res = append(res, v) } } + + //找出在dst不在src的元素 for _, v := range dst { - if !ContainsFunc[T](interSection, v, equal) { - ret = append(ret, v) + if !ContainsFunc[T](src, func(t T) bool { + return equal(t, v) + }) { + res = append(res, v) } } - return deduplicateFunc[T](ret, equal) -} -func min(src, dst int) int { - if src > dst { - return dst - } - return src + return deduplicateFunc[T](res, equal) } diff --git a/slice/symmetric_diff_test.go b/slice/symmetric_diff_test.go index 7ba0f9da..59060c7b 100644 --- a/slice/symmetric_diff_test.go +++ b/slice/symmetric_diff_test.go @@ -30,31 +30,70 @@ func TestSymmetricDiffSet(t *testing.T) { want []int }{ { - src: []int{1, 2, 4, 3}, - dst: []int{4, 5, 6, 1}, - want: []int{2, 3, 5, 6}, - name: "normal test", + name: "no inter", + src: []int{1, 2, 3}, + dst: []int{4, 5, 6}, + want: []int{1, 2, 3, 4, 5, 6}, }, { - src: []int{1, 1, 2, 3, 4}, - dst: []int{4, 5, 6, 1, 7, 6}, - want: []int{3, 6, 7, 5, 2}, - name: "deduplicate", + name: "part inter", + src: []int{1, 2, 3}, + dst: []int{3, 4, 5}, + want: []int{1, 2, 4, 5}, }, { - src: []int{}, - dst: []int{1}, + name: "src contain dst", + src: []int{1, 2, 3}, + dst: []int{2, 3}, want: []int{1}, - name: "src length is 0", }, { - src: []int{1, 3, 5}, - dst: []int{2, 4}, - want: []int{1, 3, 2, 4, 5}, - name: "not exist same ele", + name: "dst contain src", + src: []int{4}, + dst: []int{4, 5, 6}, + want: []int{5, 6}, + }, + { + name: "equal", + src: []int{1, 2, 3}, + dst: []int{1, 2, 3}, + want: []int{}, + }, + { + name: "dst empty", + src: []int{1, 2, 3}, + dst: []int{}, + want: []int{1, 2, 3}, + }, + { + name: "src empty", + src: []int{}, + dst: []int{4, 5, 6}, + want: []int{4, 5, 6}, + }, + { + name: "all empty", + src: []int{}, + dst: []int{}, + want: []int{}, + }, + { + name: "src nil", + src: nil, + dst: []int{4, 5, 6}, + want: []int{4, 5, 6}, + }, + { + name: "dst nil", + src: []int{4, 5, 6}, + dst: nil, + want: []int{4, 5, 6}, }, { name: "both nil", + src: nil, + dst: nil, + want: []int{}, }, } for _, tt := range tests { @@ -73,31 +112,70 @@ func TestSymmetricDiffSetFunc(t *testing.T) { want []int }{ { - src: []int{1, 2, 3, 4}, - dst: []int{4, 5, 6, 1}, - want: []int{2, 3, 5, 6}, - name: "normal test", + name: "no inter", + src: []int{1, 2, 3}, + dst: []int{4, 5, 6}, + want: []int{1, 2, 3, 4, 5, 6}, }, { - src: []int{1, 1, 2, 3, 4}, - dst: []int{4, 5, 6, 1, 7, 6}, - want: []int{3, 6, 7, 5, 2}, - name: "deduplicate", + name: "part inter", + src: []int{1, 2, 3}, + dst: []int{3, 4, 5}, + want: []int{1, 2, 4, 5}, }, { - src: []int{}, - dst: []int{1}, + name: "src contain dst", + src: []int{1, 2, 3}, + dst: []int{2, 3}, want: []int{1}, - name: "src length is 0", }, { - src: []int{1, 3, 5}, - dst: []int{2, 4}, - want: []int{1, 3, 2, 4, 5}, - name: "not exist same ele", + name: "dst contain src", + src: []int{4}, + dst: []int{4, 5, 6}, + want: []int{5, 6}, + }, + { + name: "equal", + src: []int{1, 2, 3}, + dst: []int{1, 2, 3}, + want: []int{}, + }, + { + name: "dst empty", + src: []int{1, 2, 3}, + dst: []int{}, + want: []int{1, 2, 3}, + }, + { + name: "src empty", + src: []int{}, + dst: []int{4, 5, 6}, + want: []int{4, 5, 6}, + }, + { + name: "all empty", + src: []int{}, + dst: []int{}, + want: []int{}, + }, + { + name: "src nil", + src: nil, + dst: []int{4, 5, 6}, + want: []int{4, 5, 6}, + }, + { + name: "dst nil", + src: []int{4, 5, 6}, + dst: nil, + want: []int{4, 5, 6}, }, { name: "both nil", + src: nil, + dst: nil, + want: []int{}, }, } for _, tt := range tests { diff --git a/slice/types.go b/slice/types.go index 79014b45..9e32fff1 100644 --- a/slice/types.go +++ b/slice/types.go @@ -16,3 +16,5 @@ package slice // equalFunc 比较两个元素是否相等 type equalFunc[T any] func(src, dst T) bool + +type matchFunc[T any] func(src T) bool diff --git a/sqlx/encrypt.go b/sqlx/encrypt.go index a932179b..0fdee3c0 100644 --- a/sqlx/encrypt.go +++ b/sqlx/encrypt.go @@ -93,9 +93,6 @@ func (e *EncryptColumn[T]) Scan(src any) error { b, err = e.aesDecrypt(value) case string: b, err = e.aesDecrypt([]byte(value)) - if err != nil { - return nil - } default: return fmt.Errorf("ekit:EncryptColumn.Scan 不支持 src 类型 %v", src) } diff --git a/sqlx/scanner.go b/sqlx/scanner.go new file mode 100644 index 00000000..5b1b296e --- /dev/null +++ b/sqlx/scanner.go @@ -0,0 +1,106 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlx + +import ( + "errors" + "fmt" + "reflect" +) + +var ( + ErrNoMoreRows = errors.New("ekit: 已读取完") + errInvalidArgument = errors.New("ekit: 参数非法") + _ Scanner = &sqlRowsScanner{} +) + +// Scanner 用于简化sql.Rows包中的Scan操作 +// Scanner 不会关闭sql.Rows,用户需要对此负责 +type Scanner interface { + Scan() (values []any, err error) + // ScanAll 扫描当前结果集的全部数据 + ScanAll() (allValues [][]any, err error) + // NextResultSet 移动到下一个结果集 + NextResultSet() bool +} + +type sqlRowsScanner struct { + sqlRows Rows + columnValuePointers []any +} + +// NewSQLRowsScanner 返回一个Scanner +func NewSQLRowsScanner(r Rows) (Scanner, error) { + if r == nil { + return nil, fmt.Errorf("%w *sql.Rows不能为nil", errInvalidArgument) + } + columnTypes, err := r.ColumnTypes() + if err != nil || len(columnTypes) < 1 { + return nil, fmt.Errorf("%w 无法获取*sql.Rows列类型信息: %v", errInvalidArgument, err) + } + columnValuePointers := make([]any, len(columnTypes)) + for i, columnType := range columnTypes { + typ := columnType.ScanType() + for typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + columnValuePointers[i] = reflect.New(typ).Interface() + } + return &sqlRowsScanner{sqlRows: r, columnValuePointers: columnValuePointers}, nil +} + +func (s *sqlRowsScanner) NextResultSet() bool { + return s.sqlRows.NextResultSet() +} + +// Scan 返回一行 +func (s *sqlRowsScanner) Scan() ([]any, error) { + if !s.sqlRows.Next() { + if err := s.sqlRows.Err(); err != nil { + return nil, err + } + + return nil, fmt.Errorf("%w", ErrNoMoreRows) + } + err := s.sqlRows.Scan(s.columnValuePointers...) + if err != nil { + return nil, err + } + return s.columnValues(), nil +} + +func (s *sqlRowsScanner) columnValues() []any { + values := make([]any, len(s.columnValuePointers)) + for i := 0; i < len(s.columnValuePointers); i++ { + values[i] = reflect.ValueOf(s.columnValuePointers[i]).Elem().Interface() + } + return values +} + +// ScanAll 返回所有行 +func (s *sqlRowsScanner) ScanAll() ([][]any, error) { + all := make([][]any, 0, 32) + for { + columnValues, err := s.Scan() + if err != nil { + if errors.Is(err, ErrNoMoreRows) { + break + } + return nil, err + } + all = append(all, columnValues) + } + return all, nil +} diff --git a/sqlx/scanner_test.go b/sqlx/scanner_test.go new file mode 100644 index 00000000..77756c0d --- /dev/null +++ b/sqlx/scanner_test.go @@ -0,0 +1,296 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlx + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSqlRowsScanner_New(t *testing.T) { + t.Parallel() + t.Run("当*sql.Rows为nil时,应该报错", func(t *testing.T) { + t.Parallel() + _, err := NewSQLRowsScanner(nil) + require.ErrorIs(t, err, errInvalidArgument) + }) + t.Run("当无法获取*sql.Rows列类型信息时,应该报错", func(t *testing.T) { + t.Parallel() + t.Run("*sql.Rows已关闭", func(t *testing.T) { + t.Parallel() + db, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + defer db.Close() + rows, err := db.QueryContext(context.Background(), "") + require.NoError(t, err) + require.NoError(t, rows.Close()) + + _, err = NewSQLRowsScanner(rows) + assert.Error(t, err) + }) + t.Run("*sql.Rows无列类型信息", func(t *testing.T) { + t.Parallel() + db, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + defer db.Close() + rows, err := db.QueryContext(context.Background(), "") + require.NoError(t, err) + + _, err = NewSQLRowsScanner(rows) + assert.ErrorIs(t, err, errInvalidArgument) + }) + }) +} + +func TestSqlRowsScanner_Scan(t *testing.T) { + db, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory") + require.NoError(t, err) + defer db.Close() + + query := "DROP TABLE IF EXISTS t1; CREATE TABLE t1 (\n id int primary key,\n `int` int,\n `integer` integer,\n `tinyint` TINYINT,\n `smallint` smallint,\n `MEDIUMINT` MEDIUMINT,\n `BIGINT` BIGINT,\n `UNSIGNED_BIG_INT` UNSIGNED BIG INT,\n `INT2` INT2,\n `INT8` INT8,\n `VARCHAR` VARCHAR(20),\n \t\t`CHARACTER` CHARACTER(20),\n `VARYING_CHARACTER` VARYING_CHARACTER(20),\n `NCHAR` NCHAR(23),\n `TEXT` TEXT,\n `CLOB` CLOB,\n `REAL` REAL,\n `DOUBLE` DOUBLE,\n `DOUBLE_PRECISION` DOUBLE PRECISION,\n `FLOAT` FLOAT,\n `DATETIME` DATETIME \n );" + _, err = db.ExecContext(context.Background(), query) + require.NoError(t, err) + + tests := []struct { + name string + rows *sql.Rows + want []any + cleanup func() + }{ + { + name: "浮点类型", + rows: func() *sql.Rows { + res, er := db.Exec("INSERT INTO `t1` (`REAL`,`DOUBLE`,`DOUBLE_PRECISION`, `FLOAT`) VALUES (1.0,1.0,1.0,0);") + require.NoError(t, er) + id, _ := res.LastInsertId() + q := "SELECT `REAL`,`DOUBLE`,`DOUBLE_PRECISION`,`FLOAT` FROM `t1` WHERE id=?;" + rows, er := db.QueryContext(context.Background(), q, id) + require.NoError(t, er) + return rows + }(), + want: []any{sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: true, Float64: 0}}, + cleanup: func() { + _, er := db.Exec("DELETE FROM `t1`") + require.NoError(t, er) + }, + }, + { + name: "整型", + rows: func() *sql.Rows { + res, er := db.Exec("INSERT INTO `t1` (`int`,`integer`,`tinyint`,`smallint`,`MEDIUMINT`,`BIGINT`,`UNSIGNED_BIG_INT`,`INT2`, `INT8`) VALUES (1,1,1,1,1,1,1,1,1);") + require.NoError(t, er) + q := "SELECT `int`,`integer`,`tinyint`,`smallint`,`MEDIUMINT`,`BIGINT`,`UNSIGNED_BIG_INT`,`INT2`,`INT8` FROM `t1` WHERE id=?;" + id, _ := res.LastInsertId() + rows, er := db.QueryContext(context.Background(), q, id) + require.NoError(t, er) + return rows + }(), + want: []any{sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}}, + cleanup: func() { + _, er := db.Exec("DELETE FROM `t1`") + require.NoError(t, er) + }, + }, + { + name: "string类型", + rows: func() *sql.Rows { + res, er := db.Exec("INSERT INTO `t1` (`VARCHAR`,`CHARACTER`,`VARYING_CHARACTER`,`NCHAR`,`TEXT`) VALUES ('zwl','zwl','zwl','zwl','zwl');") + require.NoError(t, er) + id, _ := res.LastInsertId() + q := "SELECT `VARCHAR`,`CHARACTER`,`VARYING_CHARACTER`,`NCHAR`,`TEXT`,`CLOB` FROM `t1` WHERE id=?;" + rows, er := db.QueryContext(context.Background(), q, id) + require.NoError(t, er) + return rows + }(), + want: []any{sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: false, String: ""}}, + cleanup: func() { + _, er := db.Exec("DELETE FROM `t1`") + require.NoError(t, er) + }, + }, + { + name: "时间类型", + rows: func() *sql.Rows { + res, er := db.Exec("INSERT INTO `t1` (`DATETIME`) VALUES ('2022-01-01 12:00:00');") + require.NoError(t, er) + id, _ := res.LastInsertId() + q := "SELECT `DATETIME` FROM `t1` WHERE id=?;" + rows, er := db.QueryContext(context.Background(), q, id) + require.NoError(t, er) + return rows + }(), + want: []any{sql.NullTime{Valid: true, Time: func() time.Time { + tim, er := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) + require.NoError(t, er) + return tim + }()}}, + cleanup: func() { + _, er := db.Exec("DELETE FROM `t1`") + require.NoError(t, er) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, err := NewSQLRowsScanner(tt.rows) + require.NoError(t, err) + for { + got, err := s.Scan() + if err != nil && errors.Is(err, ErrNoMoreRows) { + break + } + assert.NoError(t, err) + assert.Equalf(t, tt.want, got, "ScanRows(%v)", tt.rows) + } + tt.cleanup() + }) + } + + t.Run("迭代期间sql.Rows发生错误,Scan应该报错", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + expectedErr := errors.New("iteration error") + mock.ExpectQuery("SELECT").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, "John"). + AddRow(2, "Jane").RowError(1, expectedErr)) + + rows, err := db.Query("SELECT id, name FROM users") + require.NoError(t, err) + defer rows.Close() + + s, err := NewSQLRowsScanner(rows) + require.NoError(t, err) + + values, err := s.Scan() + assert.NoError(t, err) + assert.Equal(t, []any{int64(1), "John"}, values) + + _, err = s.Scan() + assert.Equal(t, expectedErr, err) + }) +} + +func TestSqlRowsScanner_ScanAll(t *testing.T) { + t.Parallel() + t.Run("迭代期间sql.Rows没有错误,ScanAll正常结束", func(t *testing.T) { + t.Parallel() + db, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory") + require.NoError(t, err) + defer db.Close() + + query := "DROP TABLE IF EXISTS t1; CREATE TABLE t1 " + + "(id int primary key," + + "`name` VARCHAR(20), " + + "`intro` TEXT, " + + "`create_time` DATETIME);" + _, err = db.ExecContext(context.Background(), query) + require.NoError(t, err) + + t1, _ := time.ParseInLocation("2006-01-02 15:04:05", "2023-02-01 19:00:01", time.UTC) + t2, _ := time.ParseInLocation("2006-01-02 15:04:05", "2023-04-01 11:00:00", time.UTC) + t3, _ := time.ParseInLocation("2006-01-02 15:04:05", "2023-02-02 09:00:23", time.UTC) + t4, _ := time.ParseInLocation("2006-01-02 15:04:05", "2023-02-04 15:00:00", time.UTC) + + _, err = db.Exec("INSERT INTO `t1` (`id`, `name`, `intro`, `create_time`) VALUES " + + "(1, 'zhangsan','这是一段中文介绍', \"2023-02-01 19:00:01\"), " + + "(2, 'lisi','这是一段中文介绍', \"2023-04-01 11:00:00\"), " + + "(3, 'wangwu','this is English introduction', \"2023-02-02 09:00:23\"), " + + "(4, 'zhaoliu','this is English introduction', \"2023-02-04 15:00:00\");") + require.NoError(t, err) + + expected := [][]any{ + {sql.NullInt64{Valid: true, Int64: 1}, sql.NullString{Valid: true, String: "zhangsan"}, sql.NullString{Valid: true, String: "这是一段中文介绍"}, sql.NullTime{Valid: true, Time: t1}}, + {sql.NullInt64{Valid: true, Int64: 2}, sql.NullString{Valid: true, String: "lisi"}, sql.NullString{Valid: true, String: "这是一段中文介绍"}, sql.NullTime{Valid: true, Time: t2}}, + {sql.NullInt64{Valid: true, Int64: 3}, sql.NullString{Valid: true, String: "wangwu"}, sql.NullString{Valid: true, String: "this is English introduction"}, sql.NullTime{Valid: true, Time: t3}}, + {sql.NullInt64{Valid: true, Int64: 4}, sql.NullString{Valid: true, String: "zhaoliu"}, sql.NullString{Valid: true, String: "this is English introduction"}, sql.NullTime{Valid: true, Time: t4}}, + } + + rows, err := db.QueryContext(context.Background(), "SELECT * FROM `t1`;") + require.NoError(t, err) + defer rows.Close() + + s, err := NewSQLRowsScanner(rows) + require.NoError(t, err) + + actual, err := s.ScanAll() + assert.NoError(t, err) + assert.Equal(t, expected, actual) + }) + t.Run("迭代期间sql.Rows发生错误,ScanAll应该报错", func(t *testing.T) { + t.Parallel() + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + expectedErr := errors.New("iteration error") + + mock.ExpectQuery("SELECT").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, "John"). + AddRow(2, "Jane").RowError(1, expectedErr)) + + rows, err := db.Query("SELECT id, name FROM users") + require.NoError(t, err) + defer rows.Close() + + s, err := NewSQLRowsScanner(rows) + require.NoError(t, err) + + _, err = s.ScanAll() + assert.Equal(t, expectedErr, err) + }) +} + +func TestSqlRowsScanner_NextResultSet(t *testing.T) { + t.Parallel() + t.Run("没有更多 ResultSet", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + mock.ExpectQuery("SELECT .*"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + rows, err := db.Query("SELECT id, name FROM users") + require.NoError(t, err) + scanner, err := NewSQLRowsScanner(rows) + require.NoError(t, err) + assert.False(t, scanner.NextResultSet()) + }) + t.Run("还有 ResultSet", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + mock.ExpectQuery("SELECT .*"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "name"}), + sqlmock.NewRows([]string{"id", "name"}), + sqlmock.NewRows([]string{"id", "name"})) + rows, err := db.Query("SELECT id, name FROM users") + require.NoError(t, err) + scanner, err := NewSQLRowsScanner(rows) + require.NoError(t, err) + assert.True(t, scanner.NextResultSet()) + assert.True(t, scanner.NextResultSet()) + assert.False(t, scanner.NextResultSet()) + }) +} diff --git a/sqlx/types.go b/sqlx/types.go new file mode 100644 index 00000000..c68eb556 --- /dev/null +++ b/sqlx/types.go @@ -0,0 +1,36 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlx + +import "database/sql" + +// 因为 sql 包里面缺乏顶级接口定义,而在研发一些中间件的时候,又必须用到不同的实现 +// 因此在这里提前定义一些顶级接口 +// 一般来说,如果你不是设计一些和数据库有关的中间件,你是用不上这些接口的 + +var _ Rows = (*sql.Rows)(nil) + +type Rows interface { + Next() bool + NextResultSet() bool + Err() error + Columns() ([]string, error) + // ColumnTypes 还是返回了原本的 sql.ColumnType + // 因为 ColumnType 同样不是一个接口,所以为了兼容 sql.Rows, + // 就只有保持这个设计 + ColumnTypes() ([]*sql.ColumnType, error) + Scan(dest ...any) error + Close() error +} diff --git a/stringx/string.go b/stringx/string.go new file mode 100644 index 00000000..91572221 --- /dev/null +++ b/stringx/string.go @@ -0,0 +1,37 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stringx + +import ( + "unsafe" +) + +// 确保传入的字符串和字节切片的生命周期足够长,不会在转换后被释放或修改。 +// 确保传入的字符串和字节切片的长度和容量是一致的,否则可能导致访问越界。 +// 不要对转换后的字节切片或字符串进行修改,因为它们可能与原始的字符串或字节切片共享底层的内存。 + +// UnsafeToBytes 非安全 string 转 []byte 他必须遵守上述规则 +func UnsafeToBytes(val string) []byte { + sh := (*[2]uintptr)(unsafe.Pointer(&val)) + bh := [3]uintptr{sh[0], sh[1], sh[1]} + return *(*[]byte)(unsafe.Pointer(&bh)) +} + +// UnsafeToString 非安全 []byte 转 string 他必须遵守上述规则 +func UnsafeToString(val []byte) string { + bh := (*[3]uintptr)(unsafe.Pointer(&val)) + sh := [2]uintptr{bh[0], bh[1]} + return *(*string)(unsafe.Pointer(&sh)) +} diff --git a/stringx/string_test.go b/stringx/string_test.go new file mode 100644 index 00000000..6f773813 --- /dev/null +++ b/stringx/string_test.go @@ -0,0 +1,160 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stringx + +import ( + "bytes" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnsafeToBytes(t *testing.T) { + testCase := []struct { + name string + val string + want []byte + }{ + { + name: "normal conversion", + val: "hello", + want: []byte("hello"), + }, + { + name: "emoji coversion", + val: "😀!hello world", + want: []byte("😀!hello world"), + }, + { + name: "chinese coversion", + val: "你好 世界!", + want: []byte("你好 世界!"), + }, + } + + for _, tt := range testCase { + t.Run(tt.name, func(t *testing.T) { + val := UnsafeToBytes(tt.val) + assert.Equal(t, tt.want, val) + }) + } +} + +func TestUnsafeToString(t *testing.T) { + testCase := []struct { + name string + before func(t *testing.T) + after func(t *testing.T) + val func(t *testing.T) []byte + want string + }{ + { + name: "normal conversion", + before: func(t *testing.T) {}, + after: func(t *testing.T) {}, + val: func(t *testing.T) []byte { + return []byte("hello") + }, + want: "hello", + }, + { + name: "emoji coversion", + before: func(t *testing.T) {}, + after: func(t *testing.T) {}, + val: func(t *testing.T) []byte { + return []byte("😀!hello world") + }, + want: "😀!hello world", + }, + { + name: "chinese coversion", + before: func(t *testing.T) {}, + after: func(t *testing.T) {}, + val: func(t *testing.T) []byte { + return []byte("你好 世界!") + }, + want: "你好 世界!", + }, + { + // 通过读取 file 文件 模拟 io.Reader 中存在的字节流 并将其转换为 string 检查他的正确性 + // 当然他必须是可控制的 + name: "file(io.Reader) read bytes stream coversion string", + before: func(t *testing.T) { + create, err := os.Create("/tmp/test_put.txt") + require.NoError(t, err) + defer create.Close() + _, err = create.WriteString("the test file...") + require.NoError(t, err) + }, + after: func(t *testing.T) { + require.NoError(t, os.Remove("/tmp/test_put.txt")) + }, + val: func(t *testing.T) []byte { + open, err := os.Open("/tmp/test_put.txt") + require.NoError(t, err) + defer open.Close() + buf := bytes.Buffer{} + _, err = buf.ReadFrom(open) + require.NoError(t, err) + return buf.Bytes() + }, + want: "the test file...", + }, + } + + for _, tt := range testCase { + t.Run(tt.name, func(t *testing.T) { + defer tt.after(t) + tt.before(t) + b := tt.val(t) + val := UnsafeToString(b) + assert.Equal(t, tt.want, val) + }) + } +} + +func Benchmark_UnsafeToBytes(b *testing.B) { + b.Run("safe to bytes", func(b *testing.B) { + s := "hello ekit! hello golang! this is test benchmark" + for i := 0; i < b.N; i++ { + _ = []byte(s) + } + }) + + b.Run("unsafe to bytes", func(b *testing.B) { + s := "hello ekit! hello golang! this is test benchmark" + for i := 0; i < b.N; i++ { + _ = UnsafeToBytes(s) + } + }) +} + +func Benchmark_UnsafeToString(b *testing.B) { + b.Run("safe to string", func(b *testing.B) { + s := []byte("hello ekit! hello golang! this is test benchmark") + for i := 0; i < b.N; i++ { + _ = string(s) + } + }) + + b.Run("unsafe to string", func(b *testing.B) { + s := []byte("hello ekit! hello golang! this is test benchmark") + for i := 0; i < b.N; i++ { + _ = UnsafeToString(s) + } + }) +} diff --git a/stringx/stringx_benchmark b/stringx/stringx_benchmark new file mode 100644 index 00000000..b5e83bcc --- /dev/null +++ b/stringx/stringx_benchmark @@ -0,0 +1,10 @@ +goos: darwin +goarch: amd64 +pkg: github.com/ecodeclub/ekit/stringx +cpu: Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz +Benchmark_UnsafeToBytes/safe_to_bytes-8 39721614 29.60 ns/op 48 B/op 1 allocs/op +Benchmark_UnsafeToBytes/unsafe_to_bytes-8 1000000000 0.2805 ns/op 0 B/op 0 allocs/op +Benchmark_UnsafeToString/safe_to_string-8 45207981 26.77 ns/op 48 B/op 1 allocs/op +Benchmark_UnsafeToString/unsafe_to_string-8 1000000000 0.2842 ns/op 0 B/op 0 allocs/op +PASS +ok github.com/ecodeclub/ekit/stringx 4.780s diff --git a/syncx/cond.go b/syncx/cond.go new file mode 100644 index 00000000..0f13e62f --- /dev/null +++ b/syncx/cond.go @@ -0,0 +1,265 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" +) + +// Cond 实现了一个条件变量,是等待或宣布一个事件发生的goroutines的汇合点。 +// +// 在改变条件和调用Wait方法的时候,Cond 关联的锁对象 L (*Mutex 或者 *RWMutex)必须被加锁, +// +// 在Go内存模型的术语中,Cond 保证 Broadcast或Signal的调用 同步于 因此而解除阻塞的 Wait 之前。 +// +// 绝大多数简单用例, 最好使用 channels 而不是 Cond +// (Broadcast 对应于关闭一个 channel, Signal 对应于给一个 channel 发送消息). +type Cond struct { + noCopy noCopy + // L 在观察或改变条件时被加锁 + L sync.Locker + notifyList *notifyList + // 用于指向自身的指针,可以用于检测是否被复制使用 + checker unsafe.Pointer + // 用于初始化notifyList + once sync.Once +} + +// NewCond 返回 关联了 l 的新 Cond . +func NewCond(l sync.Locker) *Cond { + return &Cond{L: l} +} + +// Wait 自动解锁 c.L 并挂起当前调用的 goroutine. 在恢复执行之后 Wait 在返回前将加 c.L 锁成功. +// 和其它系统不一样, 除非调用 Broadcast 或 Signal 或者 ctx 超时了,否则 Wait 不会返回. +// +// 成功唤醒时, 返回 nil. 超时失败时, 返回ctx.Err(). +// 如果 ctx 超时了, Wait 可能依旧执行成功返回 nil. +// +// 在 Wait 第一次继续执行时,因为 c.L 没有加锁, 当 Wait 返回的时候,调用者通常不能假设条件是真的 +// 相反, caller 应该在循环中调用 Wait: +// +// c.L.Lock() +// for !condition() { +// if err := c.Wait(ctx); err != nil { +// // 超时唤醒了,并不是被正常唤醒的,可以做一些超时的处理 +// } +// } +// ... condition 满足了,do work ... +// c.L.Unlock() +func (c *Cond) Wait(ctx context.Context) error { + c.checkCopy() + c.checkFirstUse() + t := c.notifyList.add() // 解锁前,将等待的对象放入链表中 + c.L.Unlock() // 一定是在等待对象放入链表后再解锁,避免刚解锁就发生协程切换,执行了signal后,再换回来导致永远阻塞 + defer c.L.Lock() + return c.notifyList.wait(ctx, t) +} + +// Signal 唤醒一个等待在 c 上的goroutine. +// +// 调用时,caller 可以持有也可以不持有 c.L 锁 +// +// Signal() 不影响 goroutine 调度的优先级; 如果其它的 goroutines +// 尝试着锁定 c.L, 它们可能在 "waiting" goroutine 之前被唤醒. +func (c *Cond) Signal() { + c.checkCopy() + c.checkFirstUse() + c.notifyList.notifyOne() +} + +// Broadcast 唤醒所有等待在 c 上的goroutine. +// +// 调用时,caller 可以持有也可以不持有 c.L 锁 +func (c *Cond) Broadcast() { + c.checkCopy() + c.checkFirstUse() + c.notifyList.notifyAll() +} + +// checkCopy 检查是否被拷贝使用 +func (c *Cond) checkCopy() { + // 判断checker保存的指针是否等于当前的指针(初始化时,并没有初始化checker的值,所以也会出现不相等) + if c.checker != unsafe.Pointer(c) && + // 由于初次初始化时,c.checker为0值,所以顺便进行一次原子替换,辅助初始化 + !atomic.CompareAndSwapPointer(&c.checker, nil, unsafe.Pointer(c)) && + // 再度检查checker保留指针是否等于当前指针 + c.checker != unsafe.Pointer(c) { + panic("syncx.Cond is copied") + } +} + +// checkFirstUse 用于初始化notifyList +func (c *Cond) checkFirstUse() { + c.once.Do(func() { + if c.notifyList == nil { + c.notifyList = newNotifyList() + } + }) +} + +// notifyList 是一个简单的 runtime_notifyList 实现,但增强了 wait 方法 +type notifyList struct { + mu sync.Mutex + list *chanList +} + +func newNotifyList() *notifyList { + return ¬ifyList{ + mu: sync.Mutex{}, + list: newChanList(), + } +} + +func (l *notifyList) add() *node { + l.mu.Lock() + defer l.mu.Unlock() + el := l.list.alloc() + l.list.pushBack(el) + return el +} + +func (l *notifyList) wait(ctx context.Context, elem *node) error { + ch := elem.Value + // 回收ch,超时时,因为没有被使用过,直接复用 + // 正常唤醒时,由于被放入了一条消息,但被取出来了一次,所以elem中的ch可以重复使用 + // 由于ch是挂在elem上的,所以elem在ch被回收之前,不可以被错误回收,所以必须在这里进行回收 + defer l.list.free(elem) + select { // 由于会随机选择一条,在超时和通知同时存在的话,如果通知先行,则没有影响,如果超时的同时,又来了通知 + case <-ctx.Done(): // 进了超时分支 + l.mu.Lock() + defer l.mu.Unlock() + select { + // double check: 检查是否在加锁前,刚好被正常通知了, + // 如果取到数据,代表收到了信号了,ch也因为被取了一次消息,意味着可以再次复用 + // 转移信号到下一个 + // 如果有下一个等待的,就唤醒它 + case <-ch: + if l.list.len() != 0 { + l.notifyNext() + } + // 如果取不到数据,代表不可能被正常唤醒了,ch也意味着没有被使用,可以从队列移除等待对象 + default: + l.list.remove(elem) + } + return ctx.Err() + case <-ch: // 如果取到数据,代表被正常唤醒了,ch也因为被取了一次消息,意味着可以再次复用 + return nil + } +} + +func (l *notifyList) notifyOne() { + l.mu.Lock() + defer l.mu.Unlock() + if l.list.len() == 0 { + return + } + l.notifyNext() +} + +func (l *notifyList) notifyNext() { + front := l.list.front() + ch := front.Value + l.list.remove(front) + ch <- struct{}{} +} + +func (l *notifyList) notifyAll() { + l.mu.Lock() + defer l.mu.Unlock() + for l.list.len() != 0 { + l.notifyNext() + } +} + +// node 保存chan的链表元素 +type node struct { + prev *node + next *node + Value chan struct{} +} + +// chanList 用于存放保存channel的一个双链表, 带复用元素的功能 +type chanList struct { + // 哨兵元素,方便处理元素个数为0的情况 + sentinel *node + size int + pool *sync.Pool +} + +func newChanList() *chanList { + sentinel := &node{} + sentinel.prev = sentinel + sentinel.next = sentinel + return &chanList{ + sentinel: sentinel, + size: 0, + pool: &sync.Pool{ + New: func() any { + return &node{ + Value: make(chan struct{}, 1), + } + }, + }, + } +} + +// len 获取链表长度 +func (l *chanList) len() int { + return l.size +} + +// front 获取队首元素 +func (l *chanList) front() *node { + return l.sentinel.next +} + +// alloc 申请新的元素,包含复用的chan +func (l *chanList) alloc() *node { + elem := l.pool.Get().(*node) + return elem +} + +// pushBack 追加元素到队尾 +func (l *chanList) pushBack(elem *node) { + elem.next = l.sentinel + elem.prev = l.sentinel.prev + l.sentinel.prev.next = elem + l.sentinel.prev = elem + l.size++ +} + +// remove 元素移除时,还不能回收该元素,避免元素上的chan被错误覆盖 +func (l *chanList) remove(elem *node) { + elem.prev.next = elem.next + elem.next.prev = elem.prev + elem.prev = nil + elem.next = nil + l.size-- +} + +// free 回收该元素,用于下次alloc获取时复用,避免再次分配 +func (l *chanList) free(elem *node) { + l.pool.Put(elem) +} + +// 用于静态代码检查复制的问题 +type noCopy struct{} + +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} diff --git a/syncx/cond_sdk_test.go b/syncx/cond_sdk_test.go new file mode 100644 index 00000000..9fb279ab --- /dev/null +++ b/syncx/cond_sdk_test.go @@ -0,0 +1,335 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file was blatantly stolen from https://cs.opensource.google/go/go/+/refs/tags/go1.19.3:src/sync/cond_test.go. + +package syncx + +import ( + "context" + "reflect" + "runtime" + "sync" + "testing" + "time" +) + +func TestCondSignal(t *testing.T) { + var m sync.Mutex + c := NewCond(&m) + n := 2 + running := make(chan bool, n) + awake := make(chan bool, n) + for i := 0; i < n; i++ { + go func() { + m.Lock() + running <- true + _ = c.Wait(context.Background()) + awake <- true + m.Unlock() + }() + } + for i := 0; i < n; i++ { + <-running // Wait for everyone to run. + } + for n > 0 { + select { + case <-awake: + t.Fatal("goroutine not asleep") + default: + } + m.Lock() + c.Signal() + m.Unlock() + <-awake // Will deadlock if no goroutine wakes up + select { + case <-awake: + t.Fatal("too many goroutines awake") + default: + } + n-- + } + c.Signal() +} + +func TestCondSignalGenerations(t *testing.T) { + var m sync.Mutex + c := NewCond(&m) + n := 100 + running := make(chan bool, n) + awake := make(chan int, n) + for i := 0; i < n; i++ { + go func(i int) { + m.Lock() + running <- true + _ = c.Wait(context.Background()) + awake <- i + m.Unlock() + }(i) + if i > 0 { + a := <-awake + if a != i-1 { + t.Fatalf("wrong goroutine woke up: want %d, got %d", i-1, a) + } + } + <-running + m.Lock() + c.Signal() + m.Unlock() + } +} + +func TestCondBroadcast(t *testing.T) { + var m sync.Mutex + c := NewCond(&m) + n := 200 + running := make(chan int, n) + awake := make(chan int, n) + exit := false + for i := 0; i < n; i++ { + go func(g int) { + m.Lock() + for !exit { + running <- g + _ = c.Wait(context.Background()) + awake <- g + } + m.Unlock() + }(i) + } + for i := 0; i < n; i++ { + for i := 0; i < n; i++ { + <-running // Will deadlock unless n are running. + } + if i == n-1 { + m.Lock() + exit = true + m.Unlock() + } + select { + case <-awake: + t.Fatal("goroutine not asleep") + default: + } + m.Lock() + c.Broadcast() + m.Unlock() + seen := make([]bool, n) + for i := 0; i < n; i++ { + g := <-awake + if seen[g] { + t.Fatal("goroutine woke up twice") + } + seen[g] = true + } + } + select { + case <-running: + t.Fatal("goroutine did not exit") + default: + } + c.Broadcast() +} + +func TestRace(t *testing.T) { + x := 0 + c := NewCond(&sync.Mutex{}) + done := make(chan bool) + go func() { + c.L.Lock() + x = 1 + _ = c.Wait(context.Background()) + if x != 2 { + t.Error("want 2") + } + x = 3 + c.Signal() + c.L.Unlock() + done <- true + }() + go func() { + c.L.Lock() + for { + if x == 1 { + x = 2 + c.Signal() + break + } + c.L.Unlock() + runtime.Gosched() + c.L.Lock() + } + c.L.Unlock() + done <- true + }() + go func() { + c.L.Lock() + for { + if x == 2 { + _ = c.Wait(context.Background()) + if x != 3 { + t.Error("want 3") + } + break + } + if x == 3 { + break + } + c.L.Unlock() + runtime.Gosched() + c.L.Lock() + } + c.L.Unlock() + done <- true + }() + <-done + <-done + <-done +} + +func TestCondSignalStealing(t *testing.T) { + for iters := 0; iters < 1000; iters++ { + var m sync.Mutex + cond := NewCond(&m) + + // Start a waiter. + ch := make(chan struct{}) + go func() { + m.Lock() + ch <- struct{}{} + _ = cond.Wait(context.Background()) + m.Unlock() + + ch <- struct{}{} + }() + + <-ch + m.Lock() + done := false + m.Unlock() + + // We know that the waiter is in the cond.Wait() call because we + // synchronized with it, then acquired/released the mutex it was + // holding when we synchronized. + // + // Start two goroutines that will race: one will broadcast on + // the cond var, the other will wait on it. + // + // The new waiter may or may not get notified, but the first one + // has to be notified. + + go func() { + cond.Broadcast() + }() + + go func() { + m.Lock() + for !done { + _ = cond.Wait(context.Background()) + } + m.Unlock() + }() + + // Check that the first waiter does get signaled. + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatalf("First waiter didn't get broadcast.") + } + + // Release the second waiter in case it didn't get the + // broadcast. + m.Lock() + done = true + m.Unlock() + cond.Broadcast() + } +} + +func TestCondCopy(t *testing.T) { + defer func() { + err := recover() + if err == nil || err.(string) != "syncx.Cond is copied" { + t.Fatalf("got %v, expect syncx.Cond is copied", err) + } + }() + c := Cond{L: &sync.Mutex{}} + c.Signal() + var c2 Cond + reflect.ValueOf(&c2).Elem().Set(reflect.ValueOf(&c).Elem()) // c2 := c, hidden from vet + c2.Signal() +} + +func BenchmarkCond1(b *testing.B) { + benchmarkCond(b, 1) +} + +func BenchmarkCond2(b *testing.B) { + benchmarkCond(b, 2) +} + +func BenchmarkCond4(b *testing.B) { + benchmarkCond(b, 4) +} + +func BenchmarkCond8(b *testing.B) { + benchmarkCond(b, 8) +} + +func BenchmarkCond16(b *testing.B) { + benchmarkCond(b, 16) +} + +// BenchmarkCond32 test.bench: 100000x 31851 ns/op 1539 B/op 32 allocs/op +func BenchmarkCond32(b *testing.B) { + benchmarkCond(b, 32) +} + +func benchmarkCond(b *testing.B, waiters int) { + c := NewCond(&sync.Mutex{}) + done := make(chan bool) + id := 0 + + for routine := 0; routine < waiters+1; routine++ { + go func() { + for i := 0; i < b.N; i++ { + c.L.Lock() + if id == -1 { + c.L.Unlock() + break + } + id++ + if id == waiters+1 { + id = 0 + c.Broadcast() + } else { + _ = c.Wait(context.Background()) + } + c.L.Unlock() + } + c.L.Lock() + id = -1 + c.Broadcast() + c.L.Unlock() + done <- true + }() + } + for routine := 0; routine < waiters+1; routine++ { + <-done + } +} diff --git a/syncx/cond_test.go b/syncx/cond_test.go new file mode 100644 index 00000000..fb889d74 --- /dev/null +++ b/syncx/cond_test.go @@ -0,0 +1,297 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import ( + "context" + "math/rand" + "reflect" + "sync" + "testing" + "time" +) + +func TestCond_Broadcast(t *testing.T) { + + cond := NewCond(&sync.Mutex{}) + + type status struct { + i int + err error + } + + sleepDuration := time.Millisecond * 100 + + var n = 100 + running := make(chan int, n) + awake := make(chan status, n) + waitSeqs := make([]int, n) + normalAwakeSeqs := make([]int, 0, n) + timeoutAwakeSeqs := make([]int, 0, n) + minTimeoutCnt := 0 + minNormalCnt := 0 + seen := make(map[int]bool, n) + for i := 0; i < n; i++ { + duration := time.Millisecond * 50 * time.Duration(rand.Int()%4+1) + if duration < sleepDuration*9/10 { + minTimeoutCnt++ + } else if duration > sleepDuration*11/10 { + minNormalCnt++ + } + go func(i int) { + cond.L.Lock() + + ctx, cancelFunc := context.WithTimeout(context.Background(), duration) + defer cancelFunc() + running <- i + err := cond.Wait(ctx) + awake <- status{ + i: i, + err: err, + } + cond.L.Unlock() + }(i) + } + for i := 0; i < n; i++ { + waitSeqs[i] = <-running + } + + time.Sleep(100 * time.Millisecond) + + cond.L.Lock() + cond.Broadcast() + cond.L.Unlock() + + for i := 0; i < n; i++ { + stat := <-awake + if seen[stat.i] { + t.Fatal("goroutine woke up twice") + } else { + seen[stat.i] = true + } + if stat.err != nil { + timeoutAwakeSeqs = append(timeoutAwakeSeqs, stat.i) + } else { + normalAwakeSeqs = append(normalAwakeSeqs, stat.i) + } + } + + if len(normalAwakeSeqs) < minNormalCnt { + t.Fatal("goroutine woke up with timeout") + } + + if len(timeoutAwakeSeqs) < minTimeoutCnt { + t.Fatal("goroutine woke up with normally") + } +} + +func TestCond_Signal(t *testing.T) { + + cond := NewCond(&sync.Mutex{}) + + type status struct { + i int + err error + } + + sleepDuration := time.Millisecond * 100 + + var n = 100 + running := make(chan int, n) + awake := make(chan status, n) + waitSeqs := make([]int, n) + normalAwakeSeqs := make([]int, 0, n) + timeoutAwakeSeqs := make([]int, 0, n) + minTimeoutCnt := 0 + minNormalCnt := 0 + seen := make(map[int]bool, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + duration := time.Millisecond * 50 * time.Duration(rand.Int()%4+1) + if duration < sleepDuration*9/10 { + minTimeoutCnt++ + } else if duration > sleepDuration*11/10 { + minNormalCnt++ + } + go func(i int) { + cond.L.Lock() + + ctx, cancelFunc := context.WithTimeout(context.Background(), duration) + defer cancelFunc() + running <- i + err := cond.Wait(ctx) + awake <- status{ + i: i, + err: err, + } + cond.L.Unlock() + wg.Done() + }(i) + } + for i := 0; i < n; i++ { + waitSeqs[i] = <-running + } + + go func() { + wg.Wait() + close(awake) + }() + + time.Sleep(100 * time.Millisecond) + + for i := 0; i < n; i++ { + cond.L.Lock() + cond.Signal() + cond.L.Unlock() + for { + stat, ok := <-awake + if !ok { + break + } + if seen[stat.i] { + t.Fatal("goroutine woke up twice") + } else { + seen[stat.i] = true + } + if stat.err != nil { + timeoutAwakeSeqs = append(timeoutAwakeSeqs, stat.i) + } else { + normalAwakeSeqs = append(normalAwakeSeqs, stat.i) + break + } + } + + } + + if len(normalAwakeSeqs) < minNormalCnt { + t.Fatal("goroutine woke up with timeout") + } + + if len(timeoutAwakeSeqs) < minTimeoutCnt { + t.Fatal("goroutine woke up with normally") + } + // 测试singnal唤醒的顺序问题 + if !isInOrder(normalAwakeSeqs, waitSeqs) { + t.Fatal("goroutine woke up not in order") + } + // 超时唤醒的肯定是乱序的,没有好办法测试顺序 + //if !isInOrder(timeoutAwakeSeqs, waitSeqs) { + // t.Fatal("goroutine woke up not in order") + //} +} + +func isInOrder(partial []int, source []int) bool { + + j := 0 + + for i := 0; i < len(partial); i++ { + matched := false + for j < len(source) { + if partial[i] == source[j] { + j++ + matched = true + break + } + j++ + continue + } + if !matched { + return false + } + } + + return true +} + +func Test_InOrder(t *testing.T) { + testcases := []struct { + name string + partial []int + source []int + want bool + }{ + {"", []int{1}, []int{1}, true}, + {"", []int{1, 3, 4}, []int{1, 2, 3, 4}, true}, + {"", []int{1, 3}, []int{1, 2, 3, 4}, true}, + {"", []int{1, 3, 2}, []int{1, 2, 3, 4}, false}, + {"", []int{1, 2, 2}, []int{1, 2, 3, 4}, false}, + {"", []int{1, 2, 3}, []int{1, 3, 2, 4}, false}, + {"", []int{1, 2, 4}, []int{1, 3, 2, 4}, true}, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + if target := isInOrder(tt.partial, tt.source); target != tt.want { + t.Errorf("get %v, want %v", target, tt.want) + } + }) + } +} + +// TestChanList 测试有序,和清空后重复使用是否有问题 +func TestChanList(t *testing.T) { + + l := newChanList() + + testcases := []struct { + name string + num int + }{ + {"", 5}, + {"", 3}, + {"", 10}, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(tt *testing.T) { + inputNodes := make([]*node, 0, testcase.num) + inputChans := make([]chan struct{}, 0, testcase.num) + for i := 0; i < testcase.num; i++ { + ele := l.alloc() + inputNodes = append(inputNodes, ele) + inputChans = append(inputChans, ele.Value) + l.pushBack(ele) + } + if length := l.len(); length != testcase.num { + t.Errorf("list.len() = %v, want %v", length, testcase.num) + } + outNodes := make([]*node, 0, testcase.num) + outChans := make([]chan struct{}, 0, testcase.num) + for l.len() != 0 { + front := l.front() + outNodes = append(outNodes, front) + outChans = append(outChans, front.Value) + l.remove(front) + } + if !reflect.DeepEqual(outChans, inputChans) { + t.Errorf("chan list is %v, but got %v", inputChans, outChans) + } + if !reflect.DeepEqual(outNodes, inputNodes) { + t.Errorf("element list is %v, but got %v", inputNodes, outNodes) + } + }) + } +} + +// BenchmarkChanList 测试有无内存分配增加的情况 +func BenchmarkChanList(b *testing.B) { + l := newChanList() + for i := 0; i < b.N; i++ { + elem := l.alloc() + l.pushBack(elem) + l.remove(elem) + } +} diff --git a/syncx/map.go b/syncx/map.go index 9ad07def..2c958784 100644 --- a/syncx/map.go +++ b/syncx/map.go @@ -41,6 +41,7 @@ func (m *Map[K, V]) Store(key K, value V) { } // LoadOrStore 加载或者存储一个键值对 +// true 代表是加载的,false 代表执行了 store func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { var anyVal any anyVal, loaded = m.m.LoadOrStore(key, value) @@ -50,6 +51,23 @@ func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { return } +// LoadOrStoreFunc 是一个优化,也就是使用该方法能够避免无意义的创建实例。 +// 如果你的初始化过程非常消耗资源,那么使用这个方法是有价值的。 +// 它的代价就是 Key 不存在的时候会多一次 Load 调用。 +// 当 fn 返回 error 的时候,LoadOrStoreFunc 也会返回 error。 +func (m *Map[K, V]) LoadOrStoreFunc(key K, fn func() (V, error)) (actual V, loaded bool, err error) { + val, ok := m.Load(key) + if ok { + return val, true, nil + } + val, err = fn() + if err != nil { + return + } + actual, loaded = m.LoadOrStore(key, val) + return +} + // LoadAndDelete 加载并且删除一个键值对 func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { var anyVal any diff --git a/syncx/map_test.go b/syncx/map_test.go index f8d8a20c..95634e34 100644 --- a/syncx/map_test.go +++ b/syncx/map_test.go @@ -15,6 +15,7 @@ package syncx import ( + "errors" "fmt" "sync" "testing" @@ -56,101 +57,213 @@ func TestMap_Load(t *testing.T) { }, } var mu Map[string, *User] - mu.Store("found", &User{Name: "found"}) - mu.Store("found but empty", &User{}) + mu.Store("found", testCases[0].wantVal) + mu.Store("found but empty", testCases[1].wantVal) mu.Store("found but nil", nil) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { val, ok := mu.Load(tc.key) assert.Equal(t, tc.wantOk, ok) - assert.Equal(t, tc.wantVal, val) + assert.Same(t, tc.wantVal, val) }) } } func TestMap_LoadOrStore(t *testing.T) { - var m = Map[string, *User]{} - val, loaded := m.LoadOrStore("Tom", &User{Name: "Tom"}) - assert.False(t, loaded) - assert.Equal(t, &User{Name: "Tom"}, val) - val, loaded = m.LoadOrStore("Tom", &User{Name: "Tom-copy"}) - assert.True(t, loaded) - assert.Equal(t, &User{Name: "Tom"}, val) + t.Run("store non-nil value", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Tom"} + val, loaded := m.LoadOrStore(user.Name, user) + assert.False(t, loaded) + assert.Same(t, user, val) + }) - val, loaded = m.LoadOrStore("Jerry", nil) - assert.False(t, loaded) - assert.Nil(t, val) + t.Run("load non-nil value", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Tom"} + val, loaded := m.LoadOrStore(user.Name, user) + assert.False(t, loaded) + assert.Same(t, user, val) - val, loaded = m.LoadOrStore("Jerry", &User{Name: "Jerry"}) - assert.True(t, loaded) - assert.Nil(t, val) + val, loaded = m.LoadOrStore("Tom", &User{Name: "Tom-copy"}) + + assert.True(t, loaded) + assert.Same(t, user, val) + }) + + t.Run("store nil value", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Jerry"} + val, loaded := m.LoadOrStore(user.Name, nil) + assert.False(t, loaded) + assert.Nil(t, val) + }) + + t.Run("load nil value", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Jerry"} + val, loaded := m.LoadOrStore(user.Name, nil) + assert.False(t, loaded) + assert.Nil(t, val) + + val, loaded = m.LoadOrStore(user.Name, user) + + assert.True(t, loaded) + assert.Nil(t, val) + }) +} + +func TestMap_LoadOrStoreFunc(t *testing.T) { + + t.Run("store non-nil value returned by func", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Tom"} + + val, loaded, err := m.LoadOrStoreFunc(user.Name, func() (*User, error) { + return user, nil + }) + + assert.NoError(t, err) + assert.False(t, loaded) + assert.Same(t, user, val) + }) + + t.Run("load non-nil value returned by func", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Tom"} + val, loaded, err := m.LoadOrStoreFunc(user.Name, func() (*User, error) { + return user, nil + }) + assert.NoError(t, err) + assert.False(t, loaded) + assert.Same(t, user, val) + + val, loaded, err = m.LoadOrStoreFunc(user.Name, func() (*User, error) { + return &User{Name: "Tom"}, nil + }) + + assert.NoError(t, err) + assert.True(t, loaded) + assert.Same(t, user, val) + }) + + t.Run("store nil value returned by func", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Tom"} + + val, loaded, err := m.LoadOrStoreFunc(user.Name, func() (*User, error) { + return nil, nil + }) + + assert.NoError(t, err) + assert.False(t, loaded) + assert.Nil(t, val) + }) + + t.Run("load nil value returned by func", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Tom"} + val, loaded, err := m.LoadOrStoreFunc(user.Name, func() (*User, error) { + return nil, nil + }) + assert.NoError(t, err) + assert.False(t, loaded) + assert.Nil(t, val) + + val, loaded, err = m.LoadOrStoreFunc(user.Name, func() (*User, error) { + return nil, nil + }) + + assert.NoError(t, err) + assert.True(t, loaded) + assert.Nil(t, val) + }) + + t.Run("got error returned by func", func(t *testing.T) { + m := Map[string, *User]{} + val, loaded, err := m.LoadOrStoreFunc("Jerry", func() (*User, error) { + return nil, errors.New("初始话失败") + }) + assert.Equal(t, err, errors.New("初始话失败")) + assert.False(t, loaded) + assert.Equal(t, (*User)(nil), val) + }) } func TestMap_LoadAndDelete(t *testing.T) { - var m = Map[string, *User]{} - m.Store("Tom", nil) - val, loaded := m.LoadAndDelete("Tom") - assert.True(t, loaded) - assert.Nil(t, val) - val, loaded = m.LoadAndDelete("Tom") - assert.False(t, loaded) - assert.Nil(t, val) + t.Run("non-nil value", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Jerry"} + m.Store("Jerry", user) - m.Store("Jerry", &User{Name: "Jerry"}) - val, loaded = m.LoadAndDelete("Jerry") - assert.True(t, loaded) - assert.Equal(t, &User{Name: "Jerry"}, val) + val, loaded := m.LoadAndDelete(user.Name) + assert.True(t, loaded) + assert.Same(t, user, val) - val, loaded = m.LoadAndDelete("Jerry") - assert.False(t, loaded) - assert.Nil(t, val) + val, loaded = m.LoadAndDelete(user.Name) + assert.False(t, loaded) + assert.Nil(t, val) + }) + + t.Run("nil value", func(t *testing.T) { + m, user := Map[string, *User]{}, &User{Name: "Tom"} + m.Store(user.Name, nil) + + val, loaded := m.LoadAndDelete(user.Name) + assert.True(t, loaded) + assert.Nil(t, val) + + val, loaded = m.LoadAndDelete(user.Name) + assert.False(t, loaded) + assert.Nil(t, val) + }) } func TestMap_Delete(t *testing.T) { - var m = Map[string, *User]{} - m.Store("Tom", &User{Name: "Tom"}) - val, ok := m.Load("Tom") + m, user := Map[string, *User]{}, &User{Name: "Tom"} + m.Store(user.Name, user) + val, ok := m.Load(user.Name) assert.True(t, ok) - assert.Equal(t, &User{Name: "Tom"}, val) - m.Delete("Tom") - val, ok = m.Load("Tom") + assert.Same(t, user, val) + + m.Delete(user.Name) + + val, ok = m.Load(user.Name) assert.False(t, ok) assert.Nil(t, val) } func TestMap_Range(t *testing.T) { - var m = Map[string, *User]{} - m.Store("Tom", &User{Name: "Tom"}) - m.Store("Jerry", &User{Name: "Jerry"}) - m.Store("nil", nil) - shadow := make(map[string]*User, 3) - m.Range(func(key string, val *User) bool { - shadow[key] = val - return true + t.Run("non-pointer type key", func(t *testing.T) { + m, tom, jerry := Map[string, *User]{}, &User{Name: "Tom"}, &User{Name: "Jerry"} + var zero *User + m.Store(tom.Name, tom) + m.Store(jerry.Name, jerry) + m.Store("zero", zero) + m.Store("nil", nil) + + shadow := make(map[string]*User, 4) + m.Range(func(key string, val *User) bool { + shadow[key] = val + return true + }) + + assert.Same(t, tom, shadow[tom.Name]) + assert.Same(t, jerry, shadow[jerry.Name]) + assert.Same(t, zero, shadow["zero"]) + assert.Same(t, (*User)(nil), shadow["nil"]) }) - assert.Equal(t, map[string]*User{ - "Tom": {Name: "Tom"}, - "Jerry": {Name: "Jerry"}, - "nil": nil, - }, shadow) - - var ptrKeyMap Map[*User, string] - key1 := &User{Name: "Tom"} - var key2 *User - ptrKeyMap.Store(key1, "Tom") - ptrKeyMap.Store(key2, "nil") - ptrShadow := make(map[*User]string, 2) - ptrKeyMap.Range(func(key *User, val string) bool { - ptrShadow[key] = val - return true + + t.Run("pointer type key", func(t *testing.T) { + m, tom := Map[*User, string]{}, &User{Name: "Tom"} + var zero *User + m.Store(tom, "Tom") + m.Store(zero, "nil") + + shadow := make(map[*User]string, 2) + m.Range(func(key *User, val string) bool { + shadow[key] = val + return true + }) + + assert.Equal(t, shadow[tom], tom.Name) + assert.Equal(t, shadow[zero], "nil") + assert.Equal(t, shadow[nil], "nil") }) - assert.Equal(t, map[*User]string{ - key1: "Tom", - nil: "nil", - }, ptrShadow) } func ExampleMap_LoadAndDelete() { @@ -204,6 +317,54 @@ func ExampleMap_LoadOrStore() { // 加载旧值 } +func ExampleMap_LoadOrStoreFunc() { + var m = Map[string, *User]{} + _, loaded, _ := m.LoadOrStoreFunc("Tom", func() (*User, error) { + return &User{Name: "Tom"}, nil + }) + // 执行存储 + if !loaded { + fmt.Println("设置了新值 Tom") + } + + _, loaded, _ = m.LoadOrStoreFunc("Tom", func() (*User, error) { + return &User{Name: "Tom-copy"}, nil + }) + // Tom 这个 key 已经存在,执行加载 + if loaded { + fmt.Println("加载旧值 Tom") + } + + _, loaded, _ = m.LoadOrStoreFunc("Jerry", func() (*User, error) { + return nil, nil + }) + // 执行存储,注意值是 nil + if !loaded { + fmt.Println("设置了新值 nil") + } + val, loaded, _ := m.LoadOrStoreFunc("Jerry", func() (*User, error) { + return &User{Name: "Jerry"}, nil + }) + // Jerry 这个 key 已经存在,执行加载,于是把原本的 nil 加载出来 + if loaded { + fmt.Printf("加载旧值 %v\n", val) + } + + _, _, err := m.LoadOrStoreFunc("Kitty", func() (*User, error) { + return nil, errors.New("初始化失败") + }) + if err != nil { + fmt.Println(err.Error()) + } + + // Output: + // 设置了新值 Tom + // 加载旧值 Tom + // 设置了新值 nil + // 加载旧值 + // 初始化失败 +} + func ExampleMap_Range() { var m Map[string, int] m.Store("Tom", 18) diff --git a/value.go b/value.go index 96beed0e..a2e89bb4 100644 --- a/value.go +++ b/value.go @@ -15,7 +15,10 @@ package ekit import ( + "errors" + "fmt" "reflect" + "strconv" "github.com/ecodeclub/ekit/internal/errs" ) @@ -38,6 +41,20 @@ func (av AnyValue) Int() (int, error) { return val, nil } +func (av AnyValue) AsInt() (int, error) { + if av.Err != nil { + return 0, av.Err + } + switch v := av.Val.(type) { + case int: + return v, nil + case string: + res, err := strconv.ParseInt(v, 10, 64) + return int(res), err + } + return 0, errs.NewErrInvalidType("int", reflect.TypeOf(av.Val).String()) +} + // IntOrDefault 返回 int 数据,或者默认值 func (av AnyValue) IntOrDefault(def int) int { val, err := av.Int() @@ -59,6 +76,20 @@ func (av AnyValue) Uint() (uint, error) { return val, nil } +func (av AnyValue) AsUint() (uint, error) { + if av.Err != nil { + return 0, av.Err + } + switch v := av.Val.(type) { + case uint: + return v, nil + case string: + res, err := strconv.ParseUint(v, 10, 64) + return uint(res), err + } + return 0, errs.NewErrInvalidType("uint", reflect.TypeOf(av.Val).String()) +} + // UintOrDefault 返回 uint 数据,或者默认值 func (av AnyValue) UintOrDefault(def uint) uint { val, err := av.Uint() @@ -68,6 +99,142 @@ func (av AnyValue) UintOrDefault(def uint) uint { return val } +func (av AnyValue) Int8() (int8, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(int8) + if !ok { + return 0, errs.NewErrInvalidType("int", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +func (av AnyValue) AsInt8() (int8, error) { + if av.Err != nil { + return 0, av.Err + } + + switch v := av.Val.(type) { + case int8: + return v, nil + case string: + res, err := strconv.ParseInt(v, 10, 64) + return int8(res), err + } + return 0, errs.NewErrInvalidType("int8", reflect.TypeOf(av.Val).String()) +} + +func (av AnyValue) Int8OrDefault(def int8) int8 { + val, err := av.Int8() + if err != nil { + return def + } + return val +} + +func (av AnyValue) Uint8() (uint8, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(uint8) + if !ok { + return 0, errs.NewErrInvalidType("uint8", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +func (av AnyValue) AsUint8() (uint8, error) { + if av.Err != nil { + return 0, av.Err + } + + switch v := av.Val.(type) { + case uint8: + return v, nil + case string: + res, err := strconv.ParseUint(v, 10, 8) + return uint8(res), err + } + return 0, errs.NewErrInvalidType("uint8", reflect.TypeOf(av.Val).String()) +} + +func (av AnyValue) Uint8OrDefault(def uint8) uint8 { + val, err := av.Uint8() + if err != nil { + return def + } + return val +} + +func (av AnyValue) Int16() (int16, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(int16) + if !ok { + return 0, errs.NewErrInvalidType("int16", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +func (av AnyValue) AsInt16() (int16, error) { + if av.Err != nil { + return 0, av.Err + } + + switch v := av.Val.(type) { + case int16: + return v, nil + case string: + res, err := strconv.ParseInt(v, 10, 16) + return int16(res), err + } + return 0, errs.NewErrInvalidType("int16", reflect.TypeOf(av.Val).String()) +} + +func (av AnyValue) Int16OrDefault(def int16) int16 { + val, err := av.Int16() + if err != nil { + return def + } + return val +} + +func (av AnyValue) Uint16() (uint16, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(uint16) + if !ok { + return 0, errs.NewErrInvalidType("uint16", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +func (av AnyValue) AsUint16() (uint16, error) { + if av.Err != nil { + return 0, av.Err + } + + switch v := av.Val.(type) { + case uint16: + return v, nil + case string: + res, err := strconv.ParseUint(v, 10, 16) + return uint16(res), err + } + return 0, errs.NewErrInvalidType("uint16", reflect.TypeOf(av.Val).String()) +} + +func (av AnyValue) Uint16OrDefault(def uint16) uint16 { + val, err := av.Uint16() + if err != nil { + return def + } + return val +} + // Int32 返回 int32 数据 func (av AnyValue) Int32() (int32, error) { if av.Err != nil { @@ -80,6 +247,20 @@ func (av AnyValue) Int32() (int32, error) { return val, nil } +func (av AnyValue) AsInt32() (int32, error) { + if av.Err != nil { + return 0, av.Err + } + switch v := av.Val.(type) { + case int32: + return v, nil + case string: + res, err := strconv.ParseInt(v, 10, 32) + return int32(res), err + } + return 0, errs.NewErrInvalidType("int32", reflect.TypeOf(av.Val).String()) +} + // Int32OrDefault 返回 int32 数据,或者默认值 func (av AnyValue) Int32OrDefault(def int32) int32 { val, err := av.Int32() @@ -101,6 +282,20 @@ func (av AnyValue) Uint32() (uint32, error) { return val, nil } +func (av AnyValue) AsUint32() (uint32, error) { + if av.Err != nil { + return 0, av.Err + } + switch v := av.Val.(type) { + case uint32: + return v, nil + case string: + res, err := strconv.ParseUint(v, 10, 32) + return uint32(res), err + } + return 0, errs.NewErrInvalidType("uint32", reflect.TypeOf(av.Val).String()) +} + // Uint32OrDefault 返回 uint32 数据,或者默认值 func (av AnyValue) Uint32OrDefault(def uint32) uint32 { val, err := av.Uint32() @@ -122,6 +317,19 @@ func (av AnyValue) Int64() (int64, error) { return val, nil } +func (av AnyValue) AsInt64() (int64, error) { + if av.Err != nil { + return 0, av.Err + } + switch v := av.Val.(type) { + case int64: + return v, nil + case string: + return strconv.ParseInt(v, 10, 64) + } + return 0, errs.NewErrInvalidType("int64", reflect.TypeOf(av.Val).String()) +} + // Int64OrDefault 返回 int64 数据,或者默认值 func (av AnyValue) Int64OrDefault(def int64) int64 { val, err := av.Int64() @@ -143,6 +351,19 @@ func (av AnyValue) Uint64() (uint64, error) { return val, nil } +func (av AnyValue) AsUint64() (uint64, error) { + if av.Err != nil { + return 0, av.Err + } + switch v := av.Val.(type) { + case uint64: + return v, nil + case string: + return strconv.ParseUint(v, 10, 64) + } + return 0, errs.NewErrInvalidType("uint64", reflect.TypeOf(av.Val).String()) +} + // Uint64OrDefault 返回 uint64 数据,或者默认值 func (av AnyValue) Uint64OrDefault(def uint64) uint64 { val, err := av.Uint64() @@ -164,6 +385,20 @@ func (av AnyValue) Float32() (float32, error) { return val, nil } +func (av AnyValue) AsFloat32() (float32, error) { + if av.Err != nil { + return 0, av.Err + } + switch v := av.Val.(type) { + case float32: + return v, nil + case string: + res, err := strconv.ParseFloat(v, 32) + return float32(res), err + } + return 0, errs.NewErrInvalidType("float32", reflect.TypeOf(av.Val).String()) +} + // Float32OrDefault 返回 float32 数据,或者默认值 func (av AnyValue) Float32OrDefault(def float32) float32 { val, err := av.Float32() @@ -185,6 +420,19 @@ func (av AnyValue) Float64() (float64, error) { return val, nil } +func (av AnyValue) AsFloat64() (float64, error) { + if av.Err != nil { + return 0, av.Err + } + switch v := av.Val.(type) { + case float64: + return v, nil + case string: + return strconv.ParseFloat(v, 64) + } + return 0, errs.NewErrInvalidType("float64", reflect.TypeOf(av.Val).String()) +} + // Float64OrDefault 返回 float64 数据,或者默认值 func (av AnyValue) Float64OrDefault(def float64) float64 { val, err := av.Float64() @@ -206,6 +454,36 @@ func (av AnyValue) String() (string, error) { return val, nil } +func (av AnyValue) AsString() (string, error) { + if av.Err != nil { + return "", av.Err + } + + var val string + valueOf := reflect.ValueOf(av.Val) + switch valueOf.Type().Kind() { + case reflect.String: + val = valueOf.String() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + val = strconv.FormatUint(valueOf.Uint(), 10) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val = strconv.FormatInt(valueOf.Int(), 10) + case reflect.Float32: + val = strconv.FormatFloat(valueOf.Float(), 'f', 10, 32) + case reflect.Float64: + val = strconv.FormatFloat(valueOf.Float(), 'f', 10, 64) + case reflect.Slice: + if valueOf.Type().Elem().Kind() != reflect.Uint8 { + return "", errs.NewErrInvalidType("[]byte", fmt.Sprintf("[]%s", valueOf.Type().Elem().Kind())) + } + val = string(valueOf.Bytes()) + default: + return "", errors.New("未兼容类型,暂时无法转换") + } + + return val, nil +} + // StringOrDefault 返回 string 数据,或者默认值 func (av AnyValue) StringOrDefault(def string) string { val, err := av.String() @@ -227,6 +505,20 @@ func (av AnyValue) Bytes() ([]byte, error) { return val, nil } +func (av AnyValue) AsBytes() ([]byte, error) { + if av.Err != nil { + return []byte{}, av.Err + } + switch v := av.Val.(type) { + case []byte: + return v, nil + case string: + return []byte(v), nil + } + + return []byte{}, errs.NewErrInvalidType("[]byte", reflect.TypeOf(av.Val).String()) +} + // BytesOrDefault 返回 []byte 数据,或者默认值 func (av AnyValue) BytesOrDefault(def []byte) []byte { val, err := av.Bytes() diff --git a/value_test.go b/value_test.go index a0cee2cb..8e9b9002 100644 --- a/value_test.go +++ b/value_test.go @@ -970,3 +970,862 @@ func TestAnyValue_BoolOrDefault(t *testing.T) { }) } } + +func TestAnyValue_Int8OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def int8 + want int8 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: int8(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: int8(0), + Err: errors.New("error"), + }, + def: 1, + want: 1, + }, + { + name: "type error case:", + val: AnyValue{ + Val: true, + }, + def: 10, + want: 10, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, av.Int8OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Int16OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def int16 + want int16 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: int16(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: int16(0), + Err: errors.New("error"), + }, + def: 1, + want: 1, + }, + { + name: "type error case:", + val: AnyValue{ + Val: true, + }, + def: 10, + want: 10, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, av.Int16OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Uint8OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def uint8 + want uint8 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: uint8(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: uint8(0), + Err: errors.New("error"), + }, + def: 1, + want: 1, + }, + { + name: "type error case:", + val: AnyValue{ + Val: true, + }, + def: 10, + want: 10, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, av.Uint8OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Uint16OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def uint16 + want uint16 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: uint16(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: uint16(0), + Err: errors.New("error"), + }, + def: 1, + want: 1, + }, + { + name: "type error case:", + val: AnyValue{ + Val: true, + }, + def: 10, + want: 10, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, av.Uint16OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_AsInt(t *testing.T) { + tests := []struct { + name string + val AnyValue + want int + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal int case:", + val: AnyValue{ + Val: int(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("int", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsInt() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsInt8(t *testing.T) { + tests := []struct { + name string + val AnyValue + want int8 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal int case:", + val: AnyValue{ + Val: int8(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("int8", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsInt8() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsInt16(t *testing.T) { + tests := []struct { + name string + val AnyValue + want int16 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal int16 case:", + val: AnyValue{ + Val: int16(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("int16", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsInt16() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsInt32(t *testing.T) { + tests := []struct { + name string + val AnyValue + want int32 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal int32 case:", + val: AnyValue{ + Val: int32(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("int32", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsInt32() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsInt64(t *testing.T) { + tests := []struct { + name string + val AnyValue + want int64 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal int64 case:", + val: AnyValue{ + Val: int64(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("int64", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsInt64() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsUint(t *testing.T) { + tests := []struct { + name string + val AnyValue + want uint + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal uint case:", + val: AnyValue{ + Val: uint(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("uint", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsUint() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsUint8(t *testing.T) { + tests := []struct { + name string + val AnyValue + want uint8 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal uint8 case:", + val: AnyValue{ + Val: uint8(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("uint8", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsUint8() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsUint16(t *testing.T) { + tests := []struct { + name string + val AnyValue + want uint16 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal uint16 case:", + val: AnyValue{ + Val: uint16(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("uint16", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsUint16() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsUint32(t *testing.T) { + tests := []struct { + name string + val AnyValue + want uint32 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal uint32 case:", + val: AnyValue{ + Val: uint32(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("uint32", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsUint32() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsUint64(t *testing.T) { + tests := []struct { + name string + val AnyValue + want uint64 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1", + }, + want: 1, + }, + { + name: "normal uint64 case:", + val: AnyValue{ + Val: uint64(2), + }, + want: 2, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("uint64", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Val: "", + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsUint64() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsFloat32(t *testing.T) { + tests := []struct { + name string + val AnyValue + want float32 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "1.01", + }, + want: 1.01, + }, + { + name: "normal float32 case:", + val: AnyValue{ + Val: float32(2.44), + }, + want: 2.44, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("float32", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsFloat32() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsFloat64(t *testing.T) { + tests := []struct { + name string + val AnyValue + want float64 + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "100.0000000000", + }, + want: 1e2, + }, + { + name: "normal float64 case:", + val: AnyValue{ + Val: float64(2.44), + }, + want: 2.44, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + err: errs.NewErrInvalidType("float64", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsFloat64() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsBytes(t *testing.T) { + tests := []struct { + name string + val AnyValue + want []byte + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "hello", + }, + want: []byte("hello"), + }, + { + name: "normal []byte case:", + val: AnyValue{ + Val: []byte{1}, + }, + want: []byte{1}, + }, + { + name: "type error case:", + val: AnyValue{ + Val: []int{1}, + }, + want: []byte{}, + err: errs.NewErrInvalidType("[]byte", "[]int"), + }, + { + name: "value exists error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + want: []byte{}, + err: errors.New("error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsBytes() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestAnyValue_AsString(t *testing.T) { + tests := []struct { + name string + val AnyValue + want string + err error + }{ + { + name: "normal string case:", + val: AnyValue{ + Val: "hello ekit", + }, + want: "hello ekit", + }, + { + name: "normal uint case:", + val: AnyValue{ + Val: uint16(1231), + }, + want: "1231", + }, + { + name: "normal int case:", + val: AnyValue{ + Val: 1, + }, + want: "1", + }, + { + name: "normal float case:", + val: AnyValue{ + Val: 1e2, + }, + want: "100.0000000000", + }, + { + name: "normal []byte case:", + val: AnyValue{ + Val: []byte{72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33}, + }, + want: "Hello, World!", + }, + { + name: "type conversion failed", + val: AnyValue{ + Val: []string{"h", "e", "llo"}, + }, + err: errs.NewErrInvalidType("[]byte", "[]string"), + }, + { + name: "type conversion failed by int", + val: AnyValue{ + Val: []int{1, 2, 3, 4, 5}, + }, + err: errs.NewErrInvalidType("[]byte", "[]int"), + }, + { + name: "unsupported type case:", + val: AnyValue{ + Val: map[string]any{ + "test": 1, + "hhh": "sss", + }, + }, + err: errors.New("未兼容类型,暂时无法转换"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.val.AsString() + assert.Equal(t, tt.want, val) + assert.Equal(t, tt.err, err) + }) + } +}