Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle postgres array #296

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 261 additions & 0 deletions array_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
package postgres

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"sync"

"gorm.io/gorm/schema"
)

type PostgresArrayHandler struct{}

type arrayScanner struct {
fieldType reflect.Type
value interface{}
}

func (s *arrayScanner) Scan(src interface{}) error {
if src == nil {
s.value = reflect.MakeSlice(s.fieldType.Elem(), 0, 0).Interface()
return nil
}

switch v := src.(type) {
case string:
// Remove the curly braces
str := strings.Trim(v, "{}")

// Handle empty array
if str == "" {
s.value = reflect.MakeSlice(s.fieldType.Elem(), 0, 0).Interface()
return nil
}

// Split the string into elements
elements := strings.Split(str, ",")

// Create a new slice with the correct type
slice := reflect.MakeSlice(s.fieldType.Elem(), len(elements), len(elements))

// Convert each element to the correct type
for i, elem := range elements {
elem = strings.Trim(elem, "\"") // Remove quotes if present
switch s.fieldType.Elem().Elem().Kind() {
case reflect.String:
slice.Index(i).SetString(elem)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if val, err := strconv.ParseInt(elem, 10, 64); err == nil {
slice.Index(i).SetInt(val)
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if val, err := strconv.ParseUint(elem, 10, 64); err == nil {
slice.Index(i).SetUint(val)
}
case reflect.Float32:
if val, err := strconv.ParseFloat(elem, 32); err == nil {
slice.Index(i).SetFloat(float64(val))
}
case reflect.Float64:
if val, err := strconv.ParseFloat(elem, 64); err == nil {
slice.Index(i).SetFloat(val)
}
case reflect.Bool:
if val, err := strconv.ParseBool(elem); err == nil {
slice.Index(i).SetBool(val)
}
}
}
s.value = slice.Interface()
return nil
}
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %s", src, s.fieldType)
}

func (s *arrayScanner) Value() (driver.Value, error) {
if s.value == nil {
return nil, nil
}
return s.value, nil
}

func (h *PostgresArrayHandler) HandleArray(field *schema.Field) error {
oldValueOf := field.ValueOf
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
value, zero := oldValueOf(ctx, v)
if zero {
return value, zero
}

return h.convertArrayToPostgres(value)
}

// Mark the field as implementing Scanner interface
field.FieldType = reflect.PtrTo(field.FieldType)

oldSet := field.Set
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
return h.handleArraySet(field, ctx, value, v, oldSet)
}

// Add Scanner implementation
if _, ok := reflect.New(field.FieldType).Interface().(sql.Scanner); !ok {
field.NewValuePool = &sync.Pool{
New: func() interface{} {
return &arrayScanner{
fieldType: field.FieldType,
}
},
}
}

return nil
}

func (h *PostgresArrayHandler) convertArrayToPostgres(value interface{}) (interface{}, bool) {
switch slice := value.(type) {
case []string:
return "{" + strings.Join(slice, ",") + "}", false
case []int:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatInt(int64(v), 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []int8:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatInt(int64(v), 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []int16:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatInt(int64(v), 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []int32:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatInt(int64(v), 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []int64:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatInt(v, 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []uint:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatUint(uint64(v), 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []uint16:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatUint(uint64(v), 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []uint32:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatUint(uint64(v), 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []uint64:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatUint(v, 10)
}
return "{" + strings.Join(strs, ",") + "}", false
case []float32:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatFloat(float64(v), 'f', -1, 32)
}
return "{" + strings.Join(strs, ",") + "}", false
case []float64:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatFloat(v, 'f', -1, 64)
}
return "{" + strings.Join(strs, ",") + "}", false
case []bool:
strs := make([]string, len(slice))
for i, v := range slice {
strs[i] = strconv.FormatBool(v)
}
return "{" + strings.Join(strs, ",") + "}", false
}
return value, false
}

func (h *PostgresArrayHandler) handleArraySet(field *schema.Field, ctx context.Context, value reflect.Value, v interface{}, oldSet func(context.Context, reflect.Value, interface{}) error) error {
if v == nil {
field.ReflectValueOf(ctx, value).Set(reflect.MakeSlice(field.FieldType.Elem(), 0, 0))
return nil
}

switch data := v.(type) {
case *arrayScanner:
if data.value != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data.value))
}
return nil
case string:
// Remove the curly braces
str := strings.Trim(data, "{}")

// Handle empty array
if str == "" {
field.ReflectValueOf(ctx, value).Set(reflect.MakeSlice(field.FieldType.Elem(), 0, 0))
return nil
}

// Split the string into elements
elements := strings.Split(str, ",")

// Create a new slice with the correct type
slice := reflect.MakeSlice(field.FieldType.Elem(), len(elements), len(elements))

// Convert each element to the correct type
for i, elem := range elements {
elem = strings.Trim(elem, "\"") // Remove quotes if present
switch field.FieldType.Elem().Elem().Kind() {
case reflect.String:
slice.Index(i).SetString(elem)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if val, err := strconv.ParseInt(elem, 10, 64); err == nil {
slice.Index(i).SetInt(val)
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if val, err := strconv.ParseUint(elem, 10, 64); err == nil {
slice.Index(i).SetUint(val)
}
case reflect.Float32:
if val, err := strconv.ParseFloat(elem, 32); err == nil {
slice.Index(i).SetFloat(val)
}
case reflect.Float64:
if val, err := strconv.ParseFloat(elem, 64); err == nil {
slice.Index(i).SetFloat(val)
}
case reflect.Bool:
if val, err := strconv.ParseBool(elem); err == nil {
slice.Index(i).SetBool(val)
}
}
}
field.ReflectValueOf(ctx, value).Set(slice)
return nil
default:
return oldSet(ctx, value, v)
}
}
42 changes: 42 additions & 0 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Config struct {
WithoutQuotingCheck bool
PreferSimpleProtocol bool
WithoutReturning bool
EnableArrayHandler bool
Conn gorm.ConnPool
}

Expand Down Expand Up @@ -104,6 +105,25 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
}
db.ConnPool = stdlib.OpenDB(*config)
}

if dialector.Config.EnableArrayHandler {
if err := db.Callback().Create().Before("gorm:create").Register("postgres:setup_array_handler", func(db *gorm.DB) {
if db.Statement.Schema != nil {
for _, field := range db.Statement.Schema.Fields {
if field.TagSettings["ARRAY_FIELD"] == "true" {
handler := &PostgresArrayHandler{}
err := handler.HandleArray(field)
if err != nil {
return
}
}
}
}
}); err != nil {
return err
}
}

return
}

Expand Down Expand Up @@ -192,6 +212,28 @@ func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
}

func (dialector Dialector) DataTypeOf(field *schema.Field) string {
// Need to change to schema.Array once https://github.com/go-gorm/postgres/pull/296 is released
if field.DataType == "array" {
elemKind := field.TagSettings["ELEM_TYPE"]
switch elemKind {
case "string":
field.Size = 0 // Let Postgres handle the size
return "text[]"
case "int", "int8", "int16", "int32", "int64":
return "integer[]"
case "uint", "uint16", "uint32", "uint64":
return "integer[]"
case "float32":
return "real[]"
case "float64":
return "double precision[]"
case "bool":
return "boolean[]"
default:
return "text[]"
}
}

switch field.DataType {
case schema.Bool:
return "boolean"
Expand Down
Loading