diff --git a/internal/postgres/utils.go b/internal/postgres/utils.go
index 0b63ccdfc8..1877d33fe7 100644
--- a/internal/postgres/utils.go
+++ b/internal/postgres/utils.go
@@ -5,9 +5,9 @@ import (
"encoding/json"
"fmt"
"reflect"
+ "strconv"
"strings"
- "github.com/gofrs/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
@@ -20,7 +20,71 @@ type PgxArray[T any] struct {
// properly handles scanning postgres arrays
func (a *PgxArray[T]) Scan(src any) error {
m := pgtype.NewMap()
- pgt, ok := m.TypeForName(strings.ToLower(a.colDataType))
+ // Register money types
+ m.RegisterType(&pgtype.Type{
+ Name: "money",
+ OID: 790,
+ Codec: pgtype.TextCodec{},
+ })
+ m.RegisterType(&pgtype.Type{
+ Name: "_money",
+ OID: 791,
+ Codec: &pgtype.ArrayCodec{
+ ElementType: &pgtype.Type{
+ Name: "money",
+ OID: 790,
+ Codec: pgtype.TextCodec{},
+ },
+ },
+ })
+
+ // Register UUID types
+ m.RegisterType(&pgtype.Type{
+ Name: "uuid",
+ OID: 2950, // UUID type OID
+ Codec: pgtype.TextCodec{},
+ })
+
+ m.RegisterType(&pgtype.Type{
+ Name: "_uuid",
+ OID: 2951,
+ Codec: &pgtype.ArrayCodec{
+ ElementType: &pgtype.Type{
+ Name: "uuid",
+ OID: 2950,
+ Codec: pgtype.TextCodec{},
+ },
+ },
+ })
+
+ // Register XML type
+ m.RegisterType(&pgtype.Type{
+ Name: "xml",
+ OID: 142,
+ Codec: pgtype.TextCodec{},
+ })
+
+ m.RegisterType(&pgtype.Type{
+ Name: "_xml",
+ OID: 143,
+ Codec: &pgtype.ArrayCodec{
+ ElementType: &pgtype.Type{
+ Name: "xml",
+ OID: 142,
+ Codec: pgtype.TextCodec{},
+ },
+ },
+ })
+
+ // Try to get the type by OID first if colDataType is numeric
+ var pgt *pgtype.Type
+ var ok bool
+
+ if oid, err := strconv.Atoi(a.colDataType); err == nil {
+ pgt, ok = m.TypeForOID(uint32(oid)) //nolint:gosec
+ } else {
+ pgt, ok = m.TypeForName(strings.ToLower(a.colDataType))
+ }
if !ok {
return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %s", a.colDataType)
}
@@ -34,13 +98,49 @@ func (a *PgxArray[T]) Scan(src any) error {
case []byte:
bufSrc = src
default:
- bufSrc = []byte(fmt.Sprint(bufSrc))
+ bufSrc = []byte(fmt.Sprint(src))
}
}
return m.Scan(pgt.OID, pgtype.TextFormatCode, bufSrc, v)
}
+type NullableJSON struct {
+ json.RawMessage
+ Valid bool
+}
+
+// Nullable JSON scanner
+func (n *NullableJSON) Scan(value any) error {
+ if value == nil {
+ n.RawMessage, n.Valid = nil, false
+ return nil
+ }
+
+ n.Valid = true
+ switch v := value.(type) {
+ case []byte:
+ n.RawMessage = json.RawMessage(v)
+ return nil
+ case string:
+ n.RawMessage = json.RawMessage(v)
+ return nil
+ default:
+ return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", value, n.RawMessage)
+ }
+}
+
+func (n *NullableJSON) Unmarshal() (any, error) {
+ if !n.Valid {
+ return nil, nil
+ }
+ var js any
+ if err := json.Unmarshal(n.RawMessage, &js); err != nil {
+ return nil, err
+ }
+ return js, nil
+}
+
func SqlRowToPgTypesMap(rows *sql.Rows) (map[string]any, error) {
columnNames, err := rows.Columns()
if err != nil {
@@ -52,51 +152,51 @@ func SqlRowToPgTypesMap(rows *sql.Rows) (map[string]any, error) {
return nil, err
}
- columnDbTypes := []string{}
- for _, c := range cTypes {
- columnDbTypes = append(columnDbTypes, c.DatabaseTypeName())
- }
-
values := make([]any, len(columnNames))
- valuesWrapped := make([]any, 0, len(columnNames))
+ scanTargets := make([]any, 0, len(columnNames))
for i := range values {
- ctype := cTypes[i]
- if IsPgArrayType(ctype.DatabaseTypeName()) {
- // use custom array type scanner
- values[i] = &PgxArray[any]{
- colDataType: ctype.DatabaseTypeName(),
- }
- valuesWrapped = append(valuesWrapped, values[i])
- } else {
- valuesWrapped = append(valuesWrapped, &values[i])
+ dbTypeName := cTypes[i].DatabaseTypeName()
+ switch {
+ case isXmlDataType(dbTypeName):
+ values[i] = &sql.NullString{}
+ scanTargets = append(scanTargets, values[i])
+ case IsJsonPgDataType(dbTypeName):
+ values[i] = &NullableJSON{}
+ scanTargets = append(scanTargets, values[i])
+ case isPgxPgArrayType(dbTypeName):
+ values[i] = &PgxArray[any]{colDataType: dbTypeName}
+ scanTargets = append(scanTargets, values[i])
+ default:
+ scanTargets = append(scanTargets, &values[i])
}
}
- if err := rows.Scan(valuesWrapped...); err != nil {
+ if err := rows.Scan(scanTargets...); err != nil {
return nil, err
}
- jObj := parsePgRowValues(values, columnNames, columnDbTypes)
+ jObj := parsePgRowValues(values, columnNames)
return jObj, nil
}
-func parsePgRowValues(values []any, columnNames, columnDbTypes []string) map[string]any {
+func parsePgRowValues(values []any, columnNames []string) map[string]any {
jObj := map[string]any{}
for i, v := range values {
col := columnNames[i]
- ctype := columnDbTypes[i]
switch t := v.(type) {
- case []byte:
- if IsJsonPgDataType(ctype) {
- var js any
- if err := json.Unmarshal(t, &js); err == nil {
- jObj[col] = js
- continue
- }
- } else if isBinaryDataType(ctype) {
- jObj[col] = t
- continue
+ case nil:
+ jObj[col] = t
+ case *sql.NullString:
+ var val any = nil
+ if t.Valid {
+ val = t.String
+ }
+ jObj[col] = val
+ case *NullableJSON:
+ js, err := t.Unmarshal()
+ if err != nil {
+ js = t
}
- jObj[col] = string(t)
+ jObj[col] = js
case *PgxArray[any]:
jObj[col] = pgArrayToGoSlice(t)
default:
@@ -106,39 +206,25 @@ func parsePgRowValues(values []any, columnNames, columnDbTypes []string) map[str
return jObj
}
-func isBinaryDataType(colDataType string) bool {
- return strings.EqualFold(colDataType, "bytea")
+func isXmlDataType(colDataType string) bool {
+ return strings.EqualFold(colDataType, "xml")
}
func IsJsonPgDataType(dataType string) bool {
return strings.EqualFold(dataType, "json") || strings.EqualFold(dataType, "jsonb")
}
-
-func isJsonArrayPgDataType(dataType string) bool {
- return strings.EqualFold(dataType, "_json") || strings.EqualFold(dataType, "_jsonb")
-}
-
-func isPgUuidArray(colDataType string) bool {
- return strings.EqualFold(colDataType, "_uuid")
-}
-
-func isPgXmlArray(colDataType string) bool {
- return strings.EqualFold(colDataType, "_xml")
-}
-
-func IsPgArrayType(dbTypeName string) bool {
- return strings.HasPrefix(dbTypeName, "_")
+func isPgxPgArrayType(dbTypeName string) bool {
+ return strings.HasPrefix(dbTypeName, "_") || dbTypeName == "791"
}
func IsPgArrayColumnDataType(colDataType string) bool {
- return strings.Contains(colDataType, "[]")
+ return strings.HasSuffix(colDataType, "[]")
}
func pgArrayToGoSlice(array *PgxArray[any]) any {
if array.Elements == nil {
return nil
}
- goSlice := convertArrayToGoType(array)
dim := array.Dimensions()
if len(dim) > 1 {
@@ -146,60 +232,9 @@ func pgArrayToGoSlice(array *PgxArray[any]) any {
for _, d := range dim {
dims = append(dims, int(d.Length))
}
- return CreateMultiDimSlice(dims, goSlice)
- }
- return goSlice
-}
-
-func convertArrayToGoType(array *PgxArray[any]) []any {
- if !isJsonArrayPgDataType(array.colDataType) {
- if isPgUuidArray(array.colDataType) {
- return convertBytesToUuidSlice(array.Elements)
- }
- if isPgXmlArray(array.colDataType) {
- return convertBytesToStringSlice(array.Elements)
- }
- return array.Elements
- }
-
- var newArray []any
- for _, e := range array.Elements {
- jsonBits, ok := e.([]byte)
- if !ok {
- newArray = append(newArray, e)
- continue
- }
-
- var js any
- err := json.Unmarshal(jsonBits, &js)
- if err != nil {
- newArray = append(newArray, e)
- } else {
- newArray = append(newArray, js)
- }
- }
-
- return newArray
-}
-
-func convertBytesToStringSlice(bytes []any) []any {
- stringSlice := []any{}
- for _, el := range bytes {
- if bits, ok := el.([]byte); ok {
- stringSlice = append(stringSlice, string(bits))
- }
- }
- return stringSlice
-}
-
-func convertBytesToUuidSlice(uuids []any) []any {
- uuidSlice := []any{}
- for _, el := range uuids {
- if id, ok := el.([16]uint8); ok {
- uuidSlice = append(uuidSlice, uuid.UUID(id).String())
- }
+ return CreateMultiDimSlice(dims, array.Elements)
}
- return uuidSlice
+ return array.Elements
}
// converts flat slice to multi-dimensional slice
diff --git a/internal/postgres/utils_test.go b/internal/postgres/utils_test.go
index 1a67469e99..65a1ce8979 100644
--- a/internal/postgres/utils_test.go
+++ b/internal/postgres/utils_test.go
@@ -1,9 +1,10 @@
package postgres
import (
+ "database/sql"
+ "encoding/json"
"testing"
- "github.com/gofrs/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/require"
)
@@ -168,39 +169,32 @@ func Test_FormatPgArrayLiteral(t *testing.T) {
func Test_parsePgRowValues(t *testing.T) {
t.Run("Multiple Columns", func(t *testing.T) {
binaryData := []byte{0x01, 0x02, 0x03}
- xmlData := []byte("value")
+ xmlStr := "value"
uuidValue := "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"
+ xmlVal := &sql.NullString{String: xmlStr, Valid: true}
+ jsonVal := &NullableJSON{RawMessage: json.RawMessage(`{"key": "value"}`), Valid: true}
+
values := []any{
"Hello",
int64(42),
true,
nil,
- []byte(`{"key": "value"}`),
+ jsonVal,
&PgxArray[any]{
Array: pgtype.Array[any]{Elements: []any{1, 2, 3}},
colDataType: "_integer",
},
binaryData,
- xmlData,
+ xmlVal,
uuidValue,
}
columnNames := []string{
"text_col", "int_col", "bool_col", "nil_col", "json_col", "array_col",
"binary_col", "xml_col", "uuid_col",
}
- cTypes := []string{
- "text",
- "integer",
- "boolean",
- "text",
- "json",
- "_integer",
- "bytea",
- "xml",
- "uuid",
- }
- result := parsePgRowValues(values, columnNames, cTypes)
+
+ result := parsePgRowValues(values, columnNames)
expected := map[string]any{
"text_col": "Hello",
"int_col": int64(42),
@@ -209,18 +203,24 @@ func Test_parsePgRowValues(t *testing.T) {
"json_col": map[string]any{"key": "value"},
"array_col": []any{1, 2, 3},
"binary_col": binaryData,
- "xml_col": string(xmlData), // Assuming XML is converted to string
+ "xml_col": xmlStr,
"uuid_col": uuidValue,
}
require.Equal(t, expected, result)
})
t.Run("JSON Columns", func(t *testing.T) {
- values := []any{[]byte(`"Hello"`), []byte(`true`), []byte(`null`), []byte(`42`), []byte(`{"items": ["book", "pen"], "count": 2, "in_stock": true}`), []byte(`[1,2,3]`)}
+ values := []any{
+ &NullableJSON{RawMessage: json.RawMessage(`"Hello"`), Valid: true},
+ &NullableJSON{RawMessage: json.RawMessage(`true`), Valid: true},
+ &NullableJSON{Valid: false},
+ &NullableJSON{RawMessage: json.RawMessage(`42`), Valid: true},
+ &NullableJSON{RawMessage: json.RawMessage(`{"items": ["book", "pen"], "count": 2, "in_stock": true}`), Valid: true},
+ &NullableJSON{RawMessage: json.RawMessage(`[1,2,3]`), Valid: true},
+ }
columnNames := []string{"text_col", "bool_col", "null_col", "int_col", "json_col", "array_col"}
- cTypes := []string{"json", "json", "json", "json", "json", "json"}
- result := parsePgRowValues(values, columnNames, cTypes)
+ result := parsePgRowValues(values, columnNames)
expected := map[string]any{
"text_col": "Hello",
@@ -236,10 +236,10 @@ func Test_parsePgRowValues(t *testing.T) {
t.Run("Multiple Array Columns", func(t *testing.T) {
binaryData1 := []byte{0x01, 0x02, 0x03}
binaryData2 := []byte{0x04, 0x05, 0x06}
- xmlData1 := []byte("value1")
- xmlData2 := []byte("value2")
- uuidValue1 := [16]uint8{0xa0, 0xee, 0xbc, 0x99, 0x9c, 0x0b, 0x4e, 0xf8, 0xbb, 0x6d, 0x6b, 0xb9, 0xbd, 0x38, 0x0a, 0x11}
- uuidValue2 := [16]uint8{0xb0, 0xee, 0xbc, 0x99, 0x9c, 0x0b, 0x4e, 0xf8, 0xbb, 0x6d, 0x6b, 0xb9, 0xbd, 0x38, 0x0a, 0x22}
+ xmlData1 := "value1"
+ xmlData2 := "value2"
+ uuidValue1 := "160075f6-4d6e-4040-b287-bd43677464fa"
+ uuidValue2 := "5f4a4b96-a74e-4502-b05b-1d96fba90657"
values := []any{
&PgxArray[any]{
@@ -255,7 +255,9 @@ func Test_parsePgRowValues(t *testing.T) {
colDataType: "_boolean",
},
&PgxArray[any]{
- Array: pgtype.Array[any]{Elements: []any{[]byte(`{"key": "value1"}`), []byte(`{"key": "value2"}`)}},
+ Array: pgtype.Array[any]{Elements: []any{
+ map[string]any{"key": "value1"}, map[string]any{"key": "value2"},
+ }},
colDataType: "_json",
},
&PgxArray[any]{
@@ -263,7 +265,10 @@ func Test_parsePgRowValues(t *testing.T) {
colDataType: "_bytea",
},
&PgxArray[any]{
- Array: pgtype.Array[any]{Elements: []any{xmlData1, xmlData2}},
+ Array: pgtype.Array[any]{Elements: []any{
+ xmlData1,
+ xmlData2,
+ }},
colDataType: "_xml",
},
&PgxArray[any]{
@@ -281,12 +286,7 @@ func Test_parsePgRowValues(t *testing.T) {
"binary_array", "xml_array", "uuid_array", "multidim_array",
}
- cTypes := []string{
- "_text", "_integer", "_boolean", "_json",
- "_bytea", "_xml", "_uuid", "_integer[]",
- }
-
- result := parsePgRowValues(values, columnNames, cTypes)
+ result := parsePgRowValues(values, columnNames)
expected := map[string]any{
"text_array": []any{"Hello", "World"},
@@ -294,8 +294,8 @@ func Test_parsePgRowValues(t *testing.T) {
"bool_array": []any{true, false},
"json_array": []any{map[string]any{"key": "value1"}, map[string]any{"key": "value2"}},
"binary_array": []any{binaryData1, binaryData2},
- "xml_array": []any{string(xmlData1), string(xmlData2)},
- "uuid_array": []any{uuid.UUID(uuidValue1).String(), uuid.UUID(uuidValue2).String()},
+ "xml_array": []any{xmlData1, xmlData2},
+ "uuid_array": []any{uuidValue1, uuidValue2},
"multidim_array": []any{[]any{1, 2}, []any{3, 4}},
}
@@ -305,4 +305,72 @@ func Test_parsePgRowValues(t *testing.T) {
require.ElementsMatch(t, actual, expectedArray)
}
})
+
+ t.Run("Null Values", func(t *testing.T) {
+ values := []any{
+ &sql.NullString{Valid: false},
+ &NullableJSON{Valid: false},
+ }
+ columnNames := []string{"null_string", "null_json"}
+
+ result := parsePgRowValues(values, columnNames)
+
+ expected := map[string]any{
+ "null_string": nil,
+ "null_json": nil,
+ }
+ require.Equal(t, expected, result)
+ })
+}
+
+func TestNullableJSON_Unmarshal(t *testing.T) {
+ tests := []struct {
+ name string
+ json NullableJSON
+ want any
+ wantErr bool
+ }{
+ {
+ name: "Invalid JSON",
+ json: NullableJSON{Valid: false},
+ want: nil,
+ wantErr: false,
+ },
+ {
+ name: "Valid string",
+ json: NullableJSON{RawMessage: json.RawMessage(`"test"`), Valid: true},
+ want: "test",
+ wantErr: false,
+ },
+ {
+ name: "Valid number",
+ json: NullableJSON{RawMessage: json.RawMessage(`42`), Valid: true},
+ want: float64(42),
+ wantErr: false,
+ },
+ {
+ name: "Valid object",
+ json: NullableJSON{RawMessage: json.RawMessage(`{"key":"value"}`), Valid: true},
+ want: map[string]any{"key": "value"},
+ wantErr: false,
+ },
+ {
+ name: "Invalid JSON content",
+ json: NullableJSON{RawMessage: json.RawMessage(`{invalid}`), Valid: true},
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := tt.json.Unmarshal()
+ if tt.wantErr {
+ require.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ require.Equal(t, tt.want, got)
+ })
+ }
}
diff --git a/worker/pkg/benthos/sql/json_processor.go b/worker/pkg/benthos/sql/json_processor.go
index 25270a444d..9d0cc1ea95 100644
--- a/worker/pkg/benthos/sql/json_processor.go
+++ b/worker/pkg/benthos/sql/json_processor.go
@@ -5,9 +5,9 @@ import (
"encoding/binary"
"encoding/json"
"strconv"
- "strings"
"github.com/lib/pq"
+ pgutil "github.com/nucleuscloud/neosync/internal/postgres"
"github.com/warpstreamlabs/bento/public/service"
)
@@ -83,7 +83,7 @@ func (p *jsonToSqlProcessor) transform(path string, root any) any {
if !ok {
return v
}
- if isPgArray(datatype) {
+ if pgutil.IsPgArrayColumnDataType(datatype) {
pgarray, err := processPgArray(v, datatype)
if err != nil {
p.logger.Errorf("unable to process PG Array: %w", err)
@@ -134,10 +134,6 @@ func processPgArray(bits []byte, datatype string) (any, error) {
}
}
-func isPgArray(datatype string) bool {
- return strings.HasSuffix(datatype, "[]")
-}
-
// handles case where json strings are not quoted
func getValidJson(jsonData []byte) ([]byte, error) {
isValidJson := json.Valid(jsonData)