diff --git a/internal/postgres/utils.go b/internal/postgres/utils.go index 0b63ccdfc8..1877d33fe7 100644 --- a/internal/postgres/utils.go +++ b/internal/postgres/utils.go @@ -5,9 +5,9 @@ import ( "encoding/json" "fmt" "reflect" + "strconv" "strings" - "github.com/gofrs/uuid" "github.com/jackc/pgx/v5/pgtype" ) @@ -20,7 +20,71 @@ type PgxArray[T any] struct { // properly handles scanning postgres arrays func (a *PgxArray[T]) Scan(src any) error { m := pgtype.NewMap() - pgt, ok := m.TypeForName(strings.ToLower(a.colDataType)) + // Register money types + m.RegisterType(&pgtype.Type{ + Name: "money", + OID: 790, + Codec: pgtype.TextCodec{}, + }) + m.RegisterType(&pgtype.Type{ + Name: "_money", + OID: 791, + Codec: &pgtype.ArrayCodec{ + ElementType: &pgtype.Type{ + Name: "money", + OID: 790, + Codec: pgtype.TextCodec{}, + }, + }, + }) + + // Register UUID types + m.RegisterType(&pgtype.Type{ + Name: "uuid", + OID: 2950, // UUID type OID + Codec: pgtype.TextCodec{}, + }) + + m.RegisterType(&pgtype.Type{ + Name: "_uuid", + OID: 2951, + Codec: &pgtype.ArrayCodec{ + ElementType: &pgtype.Type{ + Name: "uuid", + OID: 2950, + Codec: pgtype.TextCodec{}, + }, + }, + }) + + // Register XML type + m.RegisterType(&pgtype.Type{ + Name: "xml", + OID: 142, + Codec: pgtype.TextCodec{}, + }) + + m.RegisterType(&pgtype.Type{ + Name: "_xml", + OID: 143, + Codec: &pgtype.ArrayCodec{ + ElementType: &pgtype.Type{ + Name: "xml", + OID: 142, + Codec: pgtype.TextCodec{}, + }, + }, + }) + + // Try to get the type by OID first if colDataType is numeric + var pgt *pgtype.Type + var ok bool + + if oid, err := strconv.Atoi(a.colDataType); err == nil { + pgt, ok = m.TypeForOID(uint32(oid)) //nolint:gosec + } else { + pgt, ok = m.TypeForName(strings.ToLower(a.colDataType)) + } if !ok { return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %s", a.colDataType) } @@ -34,13 +98,49 @@ func (a *PgxArray[T]) Scan(src any) error { case []byte: bufSrc = src default: - bufSrc = []byte(fmt.Sprint(bufSrc)) + bufSrc = []byte(fmt.Sprint(src)) } } return m.Scan(pgt.OID, pgtype.TextFormatCode, bufSrc, v) } +type NullableJSON struct { + json.RawMessage + Valid bool +} + +// Nullable JSON scanner +func (n *NullableJSON) Scan(value any) error { + if value == nil { + n.RawMessage, n.Valid = nil, false + return nil + } + + n.Valid = true + switch v := value.(type) { + case []byte: + n.RawMessage = json.RawMessage(v) + return nil + case string: + n.RawMessage = json.RawMessage(v) + return nil + default: + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", value, n.RawMessage) + } +} + +func (n *NullableJSON) Unmarshal() (any, error) { + if !n.Valid { + return nil, nil + } + var js any + if err := json.Unmarshal(n.RawMessage, &js); err != nil { + return nil, err + } + return js, nil +} + func SqlRowToPgTypesMap(rows *sql.Rows) (map[string]any, error) { columnNames, err := rows.Columns() if err != nil { @@ -52,51 +152,51 @@ func SqlRowToPgTypesMap(rows *sql.Rows) (map[string]any, error) { return nil, err } - columnDbTypes := []string{} - for _, c := range cTypes { - columnDbTypes = append(columnDbTypes, c.DatabaseTypeName()) - } - values := make([]any, len(columnNames)) - valuesWrapped := make([]any, 0, len(columnNames)) + scanTargets := make([]any, 0, len(columnNames)) for i := range values { - ctype := cTypes[i] - if IsPgArrayType(ctype.DatabaseTypeName()) { - // use custom array type scanner - values[i] = &PgxArray[any]{ - colDataType: ctype.DatabaseTypeName(), - } - valuesWrapped = append(valuesWrapped, values[i]) - } else { - valuesWrapped = append(valuesWrapped, &values[i]) + dbTypeName := cTypes[i].DatabaseTypeName() + switch { + case isXmlDataType(dbTypeName): + values[i] = &sql.NullString{} + scanTargets = append(scanTargets, values[i]) + case IsJsonPgDataType(dbTypeName): + values[i] = &NullableJSON{} + scanTargets = append(scanTargets, values[i]) + case isPgxPgArrayType(dbTypeName): + values[i] = &PgxArray[any]{colDataType: dbTypeName} + scanTargets = append(scanTargets, values[i]) + default: + scanTargets = append(scanTargets, &values[i]) } } - if err := rows.Scan(valuesWrapped...); err != nil { + if err := rows.Scan(scanTargets...); err != nil { return nil, err } - jObj := parsePgRowValues(values, columnNames, columnDbTypes) + jObj := parsePgRowValues(values, columnNames) return jObj, nil } -func parsePgRowValues(values []any, columnNames, columnDbTypes []string) map[string]any { +func parsePgRowValues(values []any, columnNames []string) map[string]any { jObj := map[string]any{} for i, v := range values { col := columnNames[i] - ctype := columnDbTypes[i] switch t := v.(type) { - case []byte: - if IsJsonPgDataType(ctype) { - var js any - if err := json.Unmarshal(t, &js); err == nil { - jObj[col] = js - continue - } - } else if isBinaryDataType(ctype) { - jObj[col] = t - continue + case nil: + jObj[col] = t + case *sql.NullString: + var val any = nil + if t.Valid { + val = t.String + } + jObj[col] = val + case *NullableJSON: + js, err := t.Unmarshal() + if err != nil { + js = t } - jObj[col] = string(t) + jObj[col] = js case *PgxArray[any]: jObj[col] = pgArrayToGoSlice(t) default: @@ -106,39 +206,25 @@ func parsePgRowValues(values []any, columnNames, columnDbTypes []string) map[str return jObj } -func isBinaryDataType(colDataType string) bool { - return strings.EqualFold(colDataType, "bytea") +func isXmlDataType(colDataType string) bool { + return strings.EqualFold(colDataType, "xml") } func IsJsonPgDataType(dataType string) bool { return strings.EqualFold(dataType, "json") || strings.EqualFold(dataType, "jsonb") } - -func isJsonArrayPgDataType(dataType string) bool { - return strings.EqualFold(dataType, "_json") || strings.EqualFold(dataType, "_jsonb") -} - -func isPgUuidArray(colDataType string) bool { - return strings.EqualFold(colDataType, "_uuid") -} - -func isPgXmlArray(colDataType string) bool { - return strings.EqualFold(colDataType, "_xml") -} - -func IsPgArrayType(dbTypeName string) bool { - return strings.HasPrefix(dbTypeName, "_") +func isPgxPgArrayType(dbTypeName string) bool { + return strings.HasPrefix(dbTypeName, "_") || dbTypeName == "791" } func IsPgArrayColumnDataType(colDataType string) bool { - return strings.Contains(colDataType, "[]") + return strings.HasSuffix(colDataType, "[]") } func pgArrayToGoSlice(array *PgxArray[any]) any { if array.Elements == nil { return nil } - goSlice := convertArrayToGoType(array) dim := array.Dimensions() if len(dim) > 1 { @@ -146,60 +232,9 @@ func pgArrayToGoSlice(array *PgxArray[any]) any { for _, d := range dim { dims = append(dims, int(d.Length)) } - return CreateMultiDimSlice(dims, goSlice) - } - return goSlice -} - -func convertArrayToGoType(array *PgxArray[any]) []any { - if !isJsonArrayPgDataType(array.colDataType) { - if isPgUuidArray(array.colDataType) { - return convertBytesToUuidSlice(array.Elements) - } - if isPgXmlArray(array.colDataType) { - return convertBytesToStringSlice(array.Elements) - } - return array.Elements - } - - var newArray []any - for _, e := range array.Elements { - jsonBits, ok := e.([]byte) - if !ok { - newArray = append(newArray, e) - continue - } - - var js any - err := json.Unmarshal(jsonBits, &js) - if err != nil { - newArray = append(newArray, e) - } else { - newArray = append(newArray, js) - } - } - - return newArray -} - -func convertBytesToStringSlice(bytes []any) []any { - stringSlice := []any{} - for _, el := range bytes { - if bits, ok := el.([]byte); ok { - stringSlice = append(stringSlice, string(bits)) - } - } - return stringSlice -} - -func convertBytesToUuidSlice(uuids []any) []any { - uuidSlice := []any{} - for _, el := range uuids { - if id, ok := el.([16]uint8); ok { - uuidSlice = append(uuidSlice, uuid.UUID(id).String()) - } + return CreateMultiDimSlice(dims, array.Elements) } - return uuidSlice + return array.Elements } // converts flat slice to multi-dimensional slice diff --git a/internal/postgres/utils_test.go b/internal/postgres/utils_test.go index 1a67469e99..65a1ce8979 100644 --- a/internal/postgres/utils_test.go +++ b/internal/postgres/utils_test.go @@ -1,9 +1,10 @@ package postgres import ( + "database/sql" + "encoding/json" "testing" - "github.com/gofrs/uuid" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" ) @@ -168,39 +169,32 @@ func Test_FormatPgArrayLiteral(t *testing.T) { func Test_parsePgRowValues(t *testing.T) { t.Run("Multiple Columns", func(t *testing.T) { binaryData := []byte{0x01, 0x02, 0x03} - xmlData := []byte("value") + xmlStr := "value" uuidValue := "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11" + xmlVal := &sql.NullString{String: xmlStr, Valid: true} + jsonVal := &NullableJSON{RawMessage: json.RawMessage(`{"key": "value"}`), Valid: true} + values := []any{ "Hello", int64(42), true, nil, - []byte(`{"key": "value"}`), + jsonVal, &PgxArray[any]{ Array: pgtype.Array[any]{Elements: []any{1, 2, 3}}, colDataType: "_integer", }, binaryData, - xmlData, + xmlVal, uuidValue, } columnNames := []string{ "text_col", "int_col", "bool_col", "nil_col", "json_col", "array_col", "binary_col", "xml_col", "uuid_col", } - cTypes := []string{ - "text", - "integer", - "boolean", - "text", - "json", - "_integer", - "bytea", - "xml", - "uuid", - } - result := parsePgRowValues(values, columnNames, cTypes) + + result := parsePgRowValues(values, columnNames) expected := map[string]any{ "text_col": "Hello", "int_col": int64(42), @@ -209,18 +203,24 @@ func Test_parsePgRowValues(t *testing.T) { "json_col": map[string]any{"key": "value"}, "array_col": []any{1, 2, 3}, "binary_col": binaryData, - "xml_col": string(xmlData), // Assuming XML is converted to string + "xml_col": xmlStr, "uuid_col": uuidValue, } require.Equal(t, expected, result) }) t.Run("JSON Columns", func(t *testing.T) { - values := []any{[]byte(`"Hello"`), []byte(`true`), []byte(`null`), []byte(`42`), []byte(`{"items": ["book", "pen"], "count": 2, "in_stock": true}`), []byte(`[1,2,3]`)} + values := []any{ + &NullableJSON{RawMessage: json.RawMessage(`"Hello"`), Valid: true}, + &NullableJSON{RawMessage: json.RawMessage(`true`), Valid: true}, + &NullableJSON{Valid: false}, + &NullableJSON{RawMessage: json.RawMessage(`42`), Valid: true}, + &NullableJSON{RawMessage: json.RawMessage(`{"items": ["book", "pen"], "count": 2, "in_stock": true}`), Valid: true}, + &NullableJSON{RawMessage: json.RawMessage(`[1,2,3]`), Valid: true}, + } columnNames := []string{"text_col", "bool_col", "null_col", "int_col", "json_col", "array_col"} - cTypes := []string{"json", "json", "json", "json", "json", "json"} - result := parsePgRowValues(values, columnNames, cTypes) + result := parsePgRowValues(values, columnNames) expected := map[string]any{ "text_col": "Hello", @@ -236,10 +236,10 @@ func Test_parsePgRowValues(t *testing.T) { t.Run("Multiple Array Columns", func(t *testing.T) { binaryData1 := []byte{0x01, 0x02, 0x03} binaryData2 := []byte{0x04, 0x05, 0x06} - xmlData1 := []byte("value1") - xmlData2 := []byte("value2") - uuidValue1 := [16]uint8{0xa0, 0xee, 0xbc, 0x99, 0x9c, 0x0b, 0x4e, 0xf8, 0xbb, 0x6d, 0x6b, 0xb9, 0xbd, 0x38, 0x0a, 0x11} - uuidValue2 := [16]uint8{0xb0, 0xee, 0xbc, 0x99, 0x9c, 0x0b, 0x4e, 0xf8, 0xbb, 0x6d, 0x6b, 0xb9, 0xbd, 0x38, 0x0a, 0x22} + xmlData1 := "value1" + xmlData2 := "value2" + uuidValue1 := "160075f6-4d6e-4040-b287-bd43677464fa" + uuidValue2 := "5f4a4b96-a74e-4502-b05b-1d96fba90657" values := []any{ &PgxArray[any]{ @@ -255,7 +255,9 @@ func Test_parsePgRowValues(t *testing.T) { colDataType: "_boolean", }, &PgxArray[any]{ - Array: pgtype.Array[any]{Elements: []any{[]byte(`{"key": "value1"}`), []byte(`{"key": "value2"}`)}}, + Array: pgtype.Array[any]{Elements: []any{ + map[string]any{"key": "value1"}, map[string]any{"key": "value2"}, + }}, colDataType: "_json", }, &PgxArray[any]{ @@ -263,7 +265,10 @@ func Test_parsePgRowValues(t *testing.T) { colDataType: "_bytea", }, &PgxArray[any]{ - Array: pgtype.Array[any]{Elements: []any{xmlData1, xmlData2}}, + Array: pgtype.Array[any]{Elements: []any{ + xmlData1, + xmlData2, + }}, colDataType: "_xml", }, &PgxArray[any]{ @@ -281,12 +286,7 @@ func Test_parsePgRowValues(t *testing.T) { "binary_array", "xml_array", "uuid_array", "multidim_array", } - cTypes := []string{ - "_text", "_integer", "_boolean", "_json", - "_bytea", "_xml", "_uuid", "_integer[]", - } - - result := parsePgRowValues(values, columnNames, cTypes) + result := parsePgRowValues(values, columnNames) expected := map[string]any{ "text_array": []any{"Hello", "World"}, @@ -294,8 +294,8 @@ func Test_parsePgRowValues(t *testing.T) { "bool_array": []any{true, false}, "json_array": []any{map[string]any{"key": "value1"}, map[string]any{"key": "value2"}}, "binary_array": []any{binaryData1, binaryData2}, - "xml_array": []any{string(xmlData1), string(xmlData2)}, - "uuid_array": []any{uuid.UUID(uuidValue1).String(), uuid.UUID(uuidValue2).String()}, + "xml_array": []any{xmlData1, xmlData2}, + "uuid_array": []any{uuidValue1, uuidValue2}, "multidim_array": []any{[]any{1, 2}, []any{3, 4}}, } @@ -305,4 +305,72 @@ func Test_parsePgRowValues(t *testing.T) { require.ElementsMatch(t, actual, expectedArray) } }) + + t.Run("Null Values", func(t *testing.T) { + values := []any{ + &sql.NullString{Valid: false}, + &NullableJSON{Valid: false}, + } + columnNames := []string{"null_string", "null_json"} + + result := parsePgRowValues(values, columnNames) + + expected := map[string]any{ + "null_string": nil, + "null_json": nil, + } + require.Equal(t, expected, result) + }) +} + +func TestNullableJSON_Unmarshal(t *testing.T) { + tests := []struct { + name string + json NullableJSON + want any + wantErr bool + }{ + { + name: "Invalid JSON", + json: NullableJSON{Valid: false}, + want: nil, + wantErr: false, + }, + { + name: "Valid string", + json: NullableJSON{RawMessage: json.RawMessage(`"test"`), Valid: true}, + want: "test", + wantErr: false, + }, + { + name: "Valid number", + json: NullableJSON{RawMessage: json.RawMessage(`42`), Valid: true}, + want: float64(42), + wantErr: false, + }, + { + name: "Valid object", + json: NullableJSON{RawMessage: json.RawMessage(`{"key":"value"}`), Valid: true}, + want: map[string]any{"key": "value"}, + wantErr: false, + }, + { + name: "Invalid JSON content", + json: NullableJSON{RawMessage: json.RawMessage(`{invalid}`), Valid: true}, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.json.Unmarshal() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } } diff --git a/worker/pkg/benthos/sql/json_processor.go b/worker/pkg/benthos/sql/json_processor.go index 25270a444d..9d0cc1ea95 100644 --- a/worker/pkg/benthos/sql/json_processor.go +++ b/worker/pkg/benthos/sql/json_processor.go @@ -5,9 +5,9 @@ import ( "encoding/binary" "encoding/json" "strconv" - "strings" "github.com/lib/pq" + pgutil "github.com/nucleuscloud/neosync/internal/postgres" "github.com/warpstreamlabs/bento/public/service" ) @@ -83,7 +83,7 @@ func (p *jsonToSqlProcessor) transform(path string, root any) any { if !ok { return v } - if isPgArray(datatype) { + if pgutil.IsPgArrayColumnDataType(datatype) { pgarray, err := processPgArray(v, datatype) if err != nil { p.logger.Errorf("unable to process PG Array: %w", err) @@ -134,10 +134,6 @@ func processPgArray(bits []byte, datatype string) (any, error) { } } -func isPgArray(datatype string) bool { - return strings.HasSuffix(datatype, "[]") -} - // handles case where json strings are not quoted func getValidJson(jsonData []byte) ([]byte, error) { isValidJson := json.Valid(jsonData)