Skip to content

Commit

Permalink
Merge pull request #7 from flycash/dev
Browse files Browse the repository at this point in the history
sql: 支持 JsonColumn
  • Loading branch information
flycash authored Jun 28, 2022
2 parents 81d8a8f + 686daa0 commit 881223f
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# 开发中
[ekit: add ToPtr function](https://github.com/gotomicro/ekit/pull/6)
[ekit: add ToPtr function](https://github.com/gotomicro/ekit/pull/6)
[sql: 支持 JsonColumn](https://github.com/gotomicro/ekit/pull/7)
65 changes: 65 additions & 0 deletions sql/json.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2021 gotomicro
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
)

// JsonColumn 代表存储字段的 json 类型
// 主要用于没有提供默认 json 类型的数据库
// T 可以是结构体,也可以是切片或者 map
// 一切可以被 json 库所处理的类型都能被用作 T
type JsonColumn[T any] struct {
Val T
Valid bool
}

// Value 返回一个 json 串。类型是 []byte
func (j JsonColumn[T]) Value() (driver.Value, error) {
if !j.Valid {
return nil, nil
}
return json.Marshal(j.Val)
}

// Scan 将 src 转化为对象
// src 的类型必须是 []byte, *[]byte, string, sql.RawBytes, *sql.RawBytes 之一
func (j *JsonColumn[T]) Scan(src any) error {
var bs []byte
switch val := src.(type) {
case []byte:
bs = val
case *[]byte:
bs = *val
case string:
bs = []byte(val)
case sql.RawBytes:
bs = val
case *sql.RawBytes:
bs = *val
default:
return fmt.Errorf("ekit:JsonColumn.Scan 不支持 src 类型 %v", src)
}

if err := json.Unmarshal(bs, &j.Val); err != nil {
return err
}
j.Valid = true
return nil
}
141 changes: 141 additions & 0 deletions sql/json_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright 2021 gotomicro
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"database/sql"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"testing"
)

func TestJsonColumn_Value(t *testing.T) {
js := JsonColumn[User]{Valid: true, Val: User{Name: "Tom"}}
value, err := js.Value()
assert.Nil(t, err)
assert.Equal(t, []byte(`{"Name":"Tom"}`), value)
js = JsonColumn[User]{}
value, err = js.Value()
assert.Nil(t, err)
assert.Nil(t, value)
}

func TestJsonColumn_Scan(t *testing.T) {
testCases := []struct {
name string
src any
wantErr error
wantVal User
}{
{
name: "nil",
wantErr: errors.New("ekit:JsonColumn.Scan 不支持 src 类型 <nil>"),
},
{
name: "string",
src: `{"Name":"Tom"}`,
wantVal: User{Name: "Tom"},
},
{
name: "string pointer",
src: func() string {
return `{"Name":"Tom"}`
}(),
wantVal: User{Name: "Tom"},
},
{
name: "bytes",
src: []byte(`{"Name":"Tom"}`),
wantVal: User{Name: "Tom"},
},
{
name: "bytes pointer",
src: func() *[]byte {
res := []byte(`{"Name":"Tom"}`)
return &res
}(),
wantVal: User{Name: "Tom"},
},
{
name: "sql.RawBytes",
src: sql.RawBytes(`{"Name":"Tom"}`),
wantVal: User{Name: "Tom"},
},
{
name: "sql.RawBytes pointer",
src: func() *sql.RawBytes {
res := sql.RawBytes(`{"Name":"Tom"}`)
return &res
}(),
wantVal: User{Name: "Tom"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
js := &JsonColumn[User]{}
err := js.Scan(tc.src)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantVal, js.Val)
assert.True(t, js.Valid)
})
}
}

func TestJsonColumn_ScanTypes(t *testing.T) {
jsSlice := JsonColumn[[]string]{}
err := jsSlice.Scan(`["a", "b", "c"]`)
assert.Nil(t, err)
assert.Equal(t, []string{"a", "b", "c"}, jsSlice.Val)
val, err := jsSlice.Value()
assert.Nil(t, err)
assert.Equal(t, []byte(`["a","b","c"]`), val)

jsMap := JsonColumn[map[string]string]{}
err = jsMap.Scan(`{"a":"a value"}`)
assert.Nil(t, err)
val, err = jsMap.Value()
assert.Nil(t, err)
assert.Equal(t, []byte(`{"a":"a value"}`), val)
}

type User struct {
Name string
}

func ExampleJsonColumn_Value() {
js := JsonColumn[User]{Valid: true, Val: User{Name: "Tom"}}
value, err := js.Value()
if err != nil {
fmt.Println(err)
}
fmt.Print(string(value.([]byte)))
// Output:
// {"Name":"Tom"}
}

func ExampleJsonColumn_Scan() {
js := JsonColumn[User]{}
err := js.Scan(`{"Name":"Tom"}`)
if err != nil {
fmt.Println(err)
}
fmt.Print(js.Val)
// Output:
// {Tom}
}

0 comments on commit 881223f

Please sign in to comment.