Skip to content

Commit

Permalink
fix: Columns and Values should recognize pointer values too (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
taronish-stytch authored Aug 22, 2023
1 parent cbcbcfc commit 9c197bf
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 10 deletions.
14 changes: 11 additions & 3 deletions columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func columnNames(model reflect.Value, strict bool, excluded ...string) []string
continue
}

if supportedColumnType(valField.Kind()) || isValidSqlValue(valField) {
if supportedColumnType(valField) || isValidSqlValue(valField) {
names = append(names, fieldName)
}
}
Expand Down Expand Up @@ -152,13 +152,16 @@ func reflectValue(v interface{}) (reflect.Value, error) {
return vVal, nil
}

func supportedColumnType(k reflect.Kind) bool {
switch k {
func supportedColumnType(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Interface,
reflect.String:
return true
case reflect.Ptr:
ptrVal := reflect.New(v.Type().Elem())
return supportedColumnType(ptrVal.Elem())
default:
return false
}
Expand All @@ -169,6 +172,11 @@ func isValidSqlValue(v reflect.Value) bool {
// 1. It returns true for sql.driver's type check for types like time.Time
// 2. It implements the driver.Valuer interface allowing conversion directly
// into sql statements
if v.Kind() == reflect.Ptr {
ptrVal := reflect.New(v.Type().Elem())
return isValidSqlValue(ptrVal.Elem())
}

if driver.IsValue(v.Interface()) {
return true
}
Expand Down
28 changes: 27 additions & 1 deletion columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,17 @@ func TestColumnsStoresOneCacheEntryPerInstance(t *testing.T) {
assert.Equal(t, 1, after-before, "Cache size grew unexpectedly")
}

func TestValuesWorkWithValidSqlValueTypes(t *testing.T) {
func TestColumnsReturnsStructTagsWithPointers(t *testing.T) {
type personUpdate struct {
Name *string `db:"name"`
}

cols, err := Columns(&personUpdate{})
assert.NoError(t, err)
assert.EqualValues(t, []string{"name"}, cols)
}

func TestColumnsWorkWithValidSqlValueTypes(t *testing.T) {
type coupon struct {
Value int `db:"value"`
Expires time.Time `db:"expires"`
Expand All @@ -284,6 +294,18 @@ func TestValuesWorkWithValidSqlValueTypes(t *testing.T) {
assert.EqualValues(t, []string{"value", "expires"}, cols)
}

func TestColumnsWorkWithPointerValidSqlTypes(t *testing.T) {
type coupon struct {
Value int `db:"value"`
Expires *time.Time `db:"expires"`
}

c := &coupon{}
cols, err := Columns(c)
assert.NoError(t, err)
assert.EqualValues(t, []string{"value", "expires"}, cols)
}

type Pet struct {
Species string
Name string
Expand All @@ -310,3 +332,7 @@ func BenchmarkColumnsLargeStruct(b *testing.B) {
Columns(ls)
}
}

func ptr(s string) *string {
return &s
}
60 changes: 54 additions & 6 deletions example_scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func exampleDB() *sql.DB {
);`,
`INSERT INTO person (id, name) VALUES (1, 'brett', 1);`,
`INSERT INTO person (id, name) VALUES (2, 'fred', 1);`,
`INSERT INTO person (id) VALUES (3);`,
)
}

Expand Down Expand Up @@ -136,6 +137,53 @@ func ExampleRowStrict() {
// {"ID":0,"Name":"brett"}
}

func ExampleRowPtr() {
db := exampleDB()
defer db.Close()
rows, err := db.Query("SELECT id,name FROM person where id = 3 LIMIT 1")
if err != nil {
panic(err)
}

var person struct {
ID int
Name *string `db:"name"`
}

err = scan.RowStrict(&person, rows)
if err != nil {
panic(err)
}

json.NewEncoder(os.Stdout).Encode(&person)
// Output:
// {"ID":0,"Name":null}
}

func ExampleRowPtrType() {
db := exampleDB()
defer db.Close()
rows, err := db.Query("SELECT id,name FROM person where id = 3 LIMIT 1")
if err != nil {
panic(err)
}

type NullableString *string
var person struct {
ID int
Name NullableString `db:"name"`
}

err = scan.RowStrict(&person, rows)
if err != nil {
panic(err)
}

json.NewEncoder(os.Stdout).Encode(&person)
// Output:
// {"ID":0,"Name":null}
}

func ExampleRow_scalar() {
db := exampleDB()
defer db.Close()
Expand Down Expand Up @@ -165,8 +213,8 @@ func ExampleRows() {
}

var persons []struct {
ID int `db:"id"`
Name string `db:"name"`
ID int `db:"id"`
Name *string `db:"name"`
}

err = scan.Rows(&persons, rows)
Expand All @@ -176,7 +224,7 @@ func ExampleRows() {

json.NewEncoder(os.Stdout).Encode(&persons)
// Output:
// [{"ID":1,"Name":"brett"},{"ID":2,"Name":"fred"}]
// [{"ID":1,"Name":"brett"},{"ID":2,"Name":"fred"},{"ID":3,"Name":null}]
}

func ExampleRowsStrict() {
Expand All @@ -189,7 +237,7 @@ func ExampleRowsStrict() {

var persons []struct {
ID int
Name string `db:"name"`
Name *string `db:"name"`
}

err = scan.Rows(&persons, rows)
Expand All @@ -199,13 +247,13 @@ func ExampleRowsStrict() {

json.NewEncoder(os.Stdout).Encode(&persons)
// Output:
// [{"ID":0,"Name":"brett"},{"ID":0,"Name":"fred"}]
// [{"ID":0,"Name":"brett"},{"ID":0,"Name":"fred"},{"ID":0,"Name":null}]
}

func ExampleRows_primitive() {
db := exampleDB()
defer db.Close()
rows, err := db.Query("SELECT name FROM person ORDER BY id ASC")
rows, err := db.Query("SELECT name FROM person WHERE name IS NOT NULL ORDER BY id ASC")
if err != nil {
panic(err)
}
Expand Down
40 changes: 40 additions & 0 deletions values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,30 @@ func TestValuesScansDBTags(t *testing.T) {
assert.EqualValues(t, []interface{}{"Brett"}, vals)
}

func TestValuesScansPointerDBTags(t *testing.T) {
type person struct {
Name *string `db:"n"`
}

p := &person{Name: ptr("Brett")}
vals, err := Values([]string{"n"}, p)
require.NoError(t, err)

assert.EqualValues(t, []interface{}{ptr("Brett")}, vals)
}

func TestValuesReturnsNilPointers(t *testing.T) {
type person struct {
Name *string `db:"n"`
}

p := &person{Name: nil}
vals, err := Values([]string{"n"}, p)
require.NoError(t, err)

assert.EqualValues(t, []interface{}{(*string)(nil)}, vals)
}

func TestValuesScansNestedFields(t *testing.T) {
type Address struct {
Street string
Expand Down Expand Up @@ -124,6 +148,22 @@ func TestValuesValidSqlTypes(t *testing.T) {
assert.EqualValues(t, []interface{}{25, tNow}, vals)
}

func TestValuesValidPointerSqlTypes(t *testing.T) {
tNow := time.Now()
type coupon struct {
Value int
Expires *time.Time
}
c := &coupon{
Value: 25,
Expires: &tNow,
}

vals, err := Values([]string{"Value", "Expires"}, c)
require.NoError(t, err)
assert.EqualValues(t, []interface{}{25, &tNow}, vals)
}

func TestValuesDriverValuerImplementers(t *testing.T) {
type person struct {
Name string `db:"name"`
Expand Down

0 comments on commit 9c197bf

Please sign in to comment.