diff --git a/array_handler.go b/array_handler.go new file mode 100644 index 0000000..0b91ecc --- /dev/null +++ b/array_handler.go @@ -0,0 +1,261 @@ +package postgres + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + + "gorm.io/gorm/schema" +) + +type PostgresArrayHandler struct{} + +type arrayScanner struct { + fieldType reflect.Type + value interface{} +} + +func (s *arrayScanner) Scan(src interface{}) error { + if src == nil { + s.value = reflect.MakeSlice(s.fieldType.Elem(), 0, 0).Interface() + return nil + } + + switch v := src.(type) { + case string: + // Remove the curly braces + str := strings.Trim(v, "{}") + + // Handle empty array + if str == "" { + s.value = reflect.MakeSlice(s.fieldType.Elem(), 0, 0).Interface() + return nil + } + + // Split the string into elements + elements := strings.Split(str, ",") + + // Create a new slice with the correct type + slice := reflect.MakeSlice(s.fieldType.Elem(), len(elements), len(elements)) + + // Convert each element to the correct type + for i, elem := range elements { + elem = strings.Trim(elem, "\"") // Remove quotes if present + switch s.fieldType.Elem().Elem().Kind() { + case reflect.String: + slice.Index(i).SetString(elem) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if val, err := strconv.ParseInt(elem, 10, 64); err == nil { + slice.Index(i).SetInt(val) + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if val, err := strconv.ParseUint(elem, 10, 64); err == nil { + slice.Index(i).SetUint(val) + } + case reflect.Float32: + if val, err := strconv.ParseFloat(elem, 32); err == nil { + slice.Index(i).SetFloat(float64(val)) + } + case reflect.Float64: + if val, err := strconv.ParseFloat(elem, 64); err == nil { + slice.Index(i).SetFloat(val) + } + case reflect.Bool: + if val, err := strconv.ParseBool(elem); err == nil { + slice.Index(i).SetBool(val) + } + } + } + s.value = slice.Interface() + return nil + } + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %s", src, s.fieldType) +} + +func (s *arrayScanner) Value() (driver.Value, error) { + if s.value == nil { + return nil, nil + } + return s.value, nil +} + +func (h *PostgresArrayHandler) HandleArray(field *schema.Field) error { + oldValueOf := field.ValueOf + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + value, zero := oldValueOf(ctx, v) + if zero { + return value, zero + } + + return h.convertArrayToPostgres(value) + } + + // Mark the field as implementing Scanner interface + field.FieldType = reflect.PtrTo(field.FieldType) + + oldSet := field.Set + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { + return h.handleArraySet(field, ctx, value, v, oldSet) + } + + // Add Scanner implementation + if _, ok := reflect.New(field.FieldType).Interface().(sql.Scanner); !ok { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &arrayScanner{ + fieldType: field.FieldType, + } + }, + } + } + + return nil +} + +func (h *PostgresArrayHandler) convertArrayToPostgres(value interface{}) (interface{}, bool) { + switch slice := value.(type) { + case []string: + return "{" + strings.Join(slice, ",") + "}", false + case []int: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatInt(int64(v), 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []int8: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatInt(int64(v), 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []int16: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatInt(int64(v), 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []int32: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatInt(int64(v), 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []int64: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatInt(v, 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []uint: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatUint(uint64(v), 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []uint16: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatUint(uint64(v), 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []uint32: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatUint(uint64(v), 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []uint64: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatUint(v, 10) + } + return "{" + strings.Join(strs, ",") + "}", false + case []float32: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatFloat(float64(v), 'f', -1, 32) + } + return "{" + strings.Join(strs, ",") + "}", false + case []float64: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatFloat(v, 'f', -1, 64) + } + return "{" + strings.Join(strs, ",") + "}", false + case []bool: + strs := make([]string, len(slice)) + for i, v := range slice { + strs[i] = strconv.FormatBool(v) + } + return "{" + strings.Join(strs, ",") + "}", false + } + return value, false +} + +func (h *PostgresArrayHandler) handleArraySet(field *schema.Field, ctx context.Context, value reflect.Value, v interface{}, oldSet func(context.Context, reflect.Value, interface{}) error) error { + if v == nil { + field.ReflectValueOf(ctx, value).Set(reflect.MakeSlice(field.FieldType.Elem(), 0, 0)) + return nil + } + + switch data := v.(type) { + case *arrayScanner: + if data.value != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data.value)) + } + return nil + case string: + // Remove the curly braces + str := strings.Trim(data, "{}") + + // Handle empty array + if str == "" { + field.ReflectValueOf(ctx, value).Set(reflect.MakeSlice(field.FieldType.Elem(), 0, 0)) + return nil + } + + // Split the string into elements + elements := strings.Split(str, ",") + + // Create a new slice with the correct type + slice := reflect.MakeSlice(field.FieldType.Elem(), len(elements), len(elements)) + + // Convert each element to the correct type + for i, elem := range elements { + elem = strings.Trim(elem, "\"") // Remove quotes if present + switch field.FieldType.Elem().Elem().Kind() { + case reflect.String: + slice.Index(i).SetString(elem) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if val, err := strconv.ParseInt(elem, 10, 64); err == nil { + slice.Index(i).SetInt(val) + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if val, err := strconv.ParseUint(elem, 10, 64); err == nil { + slice.Index(i).SetUint(val) + } + case reflect.Float32: + if val, err := strconv.ParseFloat(elem, 32); err == nil { + slice.Index(i).SetFloat(val) + } + case reflect.Float64: + if val, err := strconv.ParseFloat(elem, 64); err == nil { + slice.Index(i).SetFloat(val) + } + case reflect.Bool: + if val, err := strconv.ParseBool(elem); err == nil { + slice.Index(i).SetBool(val) + } + } + } + field.ReflectValueOf(ctx, value).Set(slice) + return nil + default: + return oldSet(ctx, value, v) + } +} diff --git a/postgres.go b/postgres.go index e865b0f..44795a8 100644 --- a/postgres.go +++ b/postgres.go @@ -27,6 +27,7 @@ type Config struct { WithoutQuotingCheck bool PreferSimpleProtocol bool WithoutReturning bool + EnableArrayHandler bool Conn gorm.ConnPool } @@ -104,6 +105,25 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { } db.ConnPool = stdlib.OpenDB(*config) } + + if dialector.Config.EnableArrayHandler { + if err := db.Callback().Create().Before("gorm:create").Register("postgres:setup_array_handler", func(db *gorm.DB) { + if db.Statement.Schema != nil { + for _, field := range db.Statement.Schema.Fields { + if field.TagSettings["ARRAY_FIELD"] == "true" { + handler := &PostgresArrayHandler{} + err := handler.HandleArray(field) + if err != nil { + return + } + } + } + } + }); err != nil { + return err + } + } + return } @@ -192,6 +212,28 @@ func (dialector Dialector) Explain(sql string, vars ...interface{}) string { } func (dialector Dialector) DataTypeOf(field *schema.Field) string { + // Need to change to schema.Array once https://github.com/go-gorm/postgres/pull/296 is released + if field.DataType == "array" { + elemKind := field.TagSettings["ELEM_TYPE"] + switch elemKind { + case "string": + field.Size = 0 // Let Postgres handle the size + return "text[]" + case "int", "int8", "int16", "int32", "int64": + return "integer[]" + case "uint", "uint16", "uint32", "uint64": + return "integer[]" + case "float32": + return "real[]" + case "float64": + return "double precision[]" + case "bool": + return "boolean[]" + default: + return "text[]" + } + } + switch field.DataType { case schema.Bool: return "boolean"