diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 3232ad6e..dc003c1e 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,2 +1,3 @@ # 开发中 -[ekit: add ToPtr function](https://github.com/gotomicro/ekit/pull/6) \ No newline at end of file +[ekit: add ToPtr function](https://github.com/gotomicro/ekit/pull/6) +[sql: 支持 JsonColumn](https://github.com/gotomicro/ekit/pull/7) \ No newline at end of file diff --git a/sql/json.go b/sql/json.go new file mode 100644 index 00000000..0838e34e --- /dev/null +++ b/sql/json.go @@ -0,0 +1,65 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" +) + +// JsonColumn 代表存储字段的 json 类型 +// 主要用于没有提供默认 json 类型的数据库 +// T 可以是结构体,也可以是切片或者 map +// 一切可以被 json 库所处理的类型都能被用作 T +type JsonColumn[T any] struct { + Val T + Valid bool +} + +// Value 返回一个 json 串。类型是 []byte +func (j JsonColumn[T]) Value() (driver.Value, error) { + if !j.Valid { + return nil, nil + } + return json.Marshal(j.Val) +} + +// Scan 将 src 转化为对象 +// src 的类型必须是 []byte, *[]byte, string, sql.RawBytes, *sql.RawBytes 之一 +func (j *JsonColumn[T]) Scan(src any) error { + var bs []byte + switch val := src.(type) { + case []byte: + bs = val + case *[]byte: + bs = *val + case string: + bs = []byte(val) + case sql.RawBytes: + bs = val + case *sql.RawBytes: + bs = *val + default: + return fmt.Errorf("ekit:JsonColumn.Scan 不支持 src 类型 %v", src) + } + + if err := json.Unmarshal(bs, &j.Val); err != nil { + return err + } + j.Valid = true + return nil +} diff --git a/sql/json_test.go b/sql/json_test.go new file mode 100644 index 00000000..6de12278 --- /dev/null +++ b/sql/json_test.go @@ -0,0 +1,141 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "database/sql" + "errors" + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestJsonColumn_Value(t *testing.T) { + js := JsonColumn[User]{Valid: true, Val: User{Name: "Tom"}} + value, err := js.Value() + assert.Nil(t, err) + assert.Equal(t, []byte(`{"Name":"Tom"}`), value) + js = JsonColumn[User]{} + value, err = js.Value() + assert.Nil(t, err) + assert.Nil(t, value) +} + +func TestJsonColumn_Scan(t *testing.T) { + testCases := []struct { + name string + src any + wantErr error + wantVal User + }{ + { + name: "nil", + wantErr: errors.New("ekit:JsonColumn.Scan 不支持 src 类型 "), + }, + { + name: "string", + src: `{"Name":"Tom"}`, + wantVal: User{Name: "Tom"}, + }, + { + name: "string pointer", + src: func() string { + return `{"Name":"Tom"}` + }(), + wantVal: User{Name: "Tom"}, + }, + { + name: "bytes", + src: []byte(`{"Name":"Tom"}`), + wantVal: User{Name: "Tom"}, + }, + { + name: "bytes pointer", + src: func() *[]byte { + res := []byte(`{"Name":"Tom"}`) + return &res + }(), + wantVal: User{Name: "Tom"}, + }, + { + name: "sql.RawBytes", + src: sql.RawBytes(`{"Name":"Tom"}`), + wantVal: User{Name: "Tom"}, + }, + { + name: "sql.RawBytes pointer", + src: func() *sql.RawBytes { + res := sql.RawBytes(`{"Name":"Tom"}`) + return &res + }(), + wantVal: User{Name: "Tom"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + js := &JsonColumn[User]{} + err := js.Scan(tc.src) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantVal, js.Val) + assert.True(t, js.Valid) + }) + } +} + +func TestJsonColumn_ScanTypes(t *testing.T) { + jsSlice := JsonColumn[[]string]{} + err := jsSlice.Scan(`["a", "b", "c"]`) + assert.Nil(t, err) + assert.Equal(t, []string{"a", "b", "c"}, jsSlice.Val) + val, err := jsSlice.Value() + assert.Nil(t, err) + assert.Equal(t, []byte(`["a","b","c"]`), val) + + jsMap := JsonColumn[map[string]string]{} + err = jsMap.Scan(`{"a":"a value"}`) + assert.Nil(t, err) + val, err = jsMap.Value() + assert.Nil(t, err) + assert.Equal(t, []byte(`{"a":"a value"}`), val) +} + +type User struct { + Name string +} + +func ExampleJsonColumn_Value() { + js := JsonColumn[User]{Valid: true, Val: User{Name: "Tom"}} + value, err := js.Value() + if err != nil { + fmt.Println(err) + } + fmt.Print(string(value.([]byte))) + // Output: + // {"Name":"Tom"} +} + +func ExampleJsonColumn_Scan() { + js := JsonColumn[User]{} + err := js.Scan(`{"Name":"Tom"}`) + if err != nil { + fmt.Println(err) + } + fmt.Print(js.Val) + // Output: + // {Tom} +}