diff --git a/internal/benthos/benthos-builder/benthos-builder.go b/internal/benthos/benthos-builder/benthos-builder.go index e9640f6470..fa801008fb 100644 --- a/internal/benthos/benthos-builder/benthos-builder.go +++ b/internal/benthos/benthos-builder/benthos-builder.go @@ -102,7 +102,6 @@ func (b *BuilderProvider) registerStandardBuilders( connectionclient mgmtv1alpha1connect.ConnectionServiceClient, redisConfig *shared.RedisConfig, selectQueryBuilder bb_shared.SelectQueryMapBuilder, - rawSqlInsertMode bool, ) error { sourceConnectionType := bb_internal.GetConnectionType(sourceConnection) jobType := bb_internal.GetJobType(job) @@ -111,22 +110,17 @@ func (b *BuilderProvider) registerStandardBuilders( connectionTypes = append(connectionTypes, bb_internal.GetConnectionType(dest)) } - sqlSyncOptions := []bb_conns.SqlSyncOption{} - if rawSqlInsertMode { - sqlSyncOptions = append(sqlSyncOptions, bb_conns.WithRawInsertMode()) - } - if jobType == bb_internal.JobTypeSync { for _, connectionType := range connectionTypes { switch connectionType { case bb_internal.ConnectionTypePostgres: - sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.PostgresDriver, selectQueryBuilder, sqlSyncOptions...) + sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.PostgresDriver, selectQueryBuilder) b.Register(bb_internal.JobTypeSync, connectionType, sqlbuilder) case bb_internal.ConnectionTypeMysql: - sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.MysqlDriver, selectQueryBuilder, sqlSyncOptions...) + sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.MysqlDriver, selectQueryBuilder) b.Register(bb_internal.JobTypeSync, connectionType, sqlbuilder) case bb_internal.ConnectionTypeMssql: - sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.MssqlDriver, selectQueryBuilder, sqlSyncOptions...) 
+ sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.MssqlDriver, selectQueryBuilder) b.Register(bb_internal.JobTypeSync, connectionType, sqlbuilder) case bb_internal.ConnectionTypeAwsS3: b.Register(bb_internal.JobTypeSync, bb_internal.ConnectionTypeAwsS3, bb_conns.NewAwsS3SyncBuilder()) @@ -217,7 +211,6 @@ type WorkerBenthosConfig struct { func NewWorkerBenthosConfigManager( config *WorkerBenthosConfig, ) (*BenthosConfigManager, error) { - rawInsertMode := false provider := NewBuilderProvider(config.Logger) err := provider.registerStandardBuilders( config.Job, @@ -228,7 +221,6 @@ func NewWorkerBenthosConfigManager( config.Connectionclient, config.RedisConfig, config.SelectQueryBuilder, - rawInsertMode, ) if err != nil { return nil, err @@ -269,7 +261,6 @@ type CliBenthosConfig struct { func NewCliBenthosConfigManager( config *CliBenthosConfig, ) (*BenthosConfigManager, error) { - rawInsertMode := true destinationProvider := NewBuilderProvider(config.Logger) err := destinationProvider.registerStandardBuilders( config.Job, @@ -280,7 +271,6 @@ func NewCliBenthosConfigManager( nil, config.RedisConfig, nil, - rawInsertMode, ) if err != nil { return nil, err diff --git a/internal/benthos/benthos-builder/builders/aws-s3.go b/internal/benthos/benthos-builder/builders/aws-s3.go index 38a1a2f9f8..ddfe3b2b32 100644 --- a/internal/benthos/benthos-builder/builders/aws-s3.go +++ b/internal/benthos/benthos-builder/builders/aws-s3.go @@ -75,17 +75,6 @@ func (b *awsS3SyncBuilder) BuildDestinationConfig(ctx context.Context, params *b storageClass = convertToS3StorageClass(destinationOpts.GetStorageClass()).String() } - processors := []*neosync_benthos.BatchProcessor{} - if isPooledSqlRawConfigured(benthosConfig.Config) { - processors = append(processors, &neosync_benthos.BatchProcessor{SqlToJson: &neosync_benthos.SqlToJsonConfig{}}) - } - - standardProcessors := []*neosync_benthos.BatchProcessor{ - {Archive: &neosync_benthos.ArchiveProcessor{Format: "lines"}}, - {Compress: &neosync_benthos.CompressProcessor{Algorithm: "gzip"}}, - } - processors = append(processors, standardProcessors...) 
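With `isPooledSqlRawConfigured` removed, the S3 destination no longer branches on the input type: every batch runs the same three-stage pipeline, shown in the hunk below. A condensed sketch of the chain it now builds:

```go
// All sources now emit Neosync-typed values, so the S3 writer always
// normalizes to JSON first, then newline-archives and gzips each batch.
processors := []*neosync_benthos.BatchProcessor{
	{NeosyncToJson: &neosync_benthos.NeosyncToJsonConfig{}},           // driver-specific values -> JSON-safe values
	{Archive: &neosync_benthos.ArchiveProcessor{Format: "lines"}},     // one JSON document per line
	{Compress: &neosync_benthos.CompressProcessor{Algorithm: "gzip"}}, // matches the "application/gzip" content type
}
```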
- config.Outputs = append(config.Outputs, neosync_benthos.Outputs{ Fallback: []neosync_benthos.Outputs{ { @@ -97,9 +86,13 @@ func (b *awsS3SyncBuilder) BuildDestinationConfig(ctx context.Context, params *b Path: strings.Join(s3pathpieces, "/"), ContentType: "application/gzip", Batching: &neosync_benthos.Batching{ - Count: batchingConfig.BatchCount, - Period: batchingConfig.BatchPeriod, - Processors: processors, + Count: batchingConfig.BatchCount, + Period: batchingConfig.BatchPeriod, + Processors: []*neosync_benthos.BatchProcessor{ + {NeosyncToJson: &neosync_benthos.NeosyncToJsonConfig{}}, + {Archive: &neosync_benthos.ArchiveProcessor{Format: "lines"}}, + {Compress: &neosync_benthos.CompressProcessor{Algorithm: "gzip"}}, + }, }, Credentials: buildBenthosS3Credentials(connAwsS3Config.Credentials), Region: connAwsS3Config.GetRegion(), @@ -120,12 +113,6 @@ func (b *awsS3SyncBuilder) BuildDestinationConfig(ctx context.Context, params *b return config, nil } -func isPooledSqlRawConfigured(cfg *neosync_benthos.BenthosConfig) bool { - return cfg != nil && - cfg.StreamConfig.Input != nil && - cfg.StreamConfig.Input.Inputs.PooledSqlRaw != nil -} - type S3StorageClass int const ( diff --git a/internal/benthos/benthos-builder/builders/benthos-builder_test.go b/internal/benthos/benthos-builder/builders/benthos-builder_test.go index 98ad28dc4a..94a41a00d5 100644 --- a/internal/benthos/benthos-builder/builders/benthos-builder_test.go +++ b/internal/benthos/benthos-builder/builders/benthos-builder_test.go @@ -511,12 +511,6 @@ func Test_convertUserDefinedFunctionConfig(t *testing.T) { require.Equal(t, resp, expected) } -func Test_buildPlainInsertArgs(t *testing.T) { - require.Empty(t, buildPlainInsertArgs(nil)) - require.Empty(t, buildPlainInsertArgs([]string{})) - require.Equal(t, buildPlainInsertArgs([]string{"foo", "bar", "baz"}), `root = [this."foo", this."bar", this."baz"]`) -} - func Test_buildPlainColumns(t *testing.T) { require.Empty(t, buildPlainColumns(nil)) require.Empty(t, buildPlainColumns([]*mgmtv1alpha1.JobMapping{})) diff --git a/internal/benthos/benthos-builder/builders/generate-ai.go b/internal/benthos/benthos-builder/builders/generate-ai.go index c9fd5511aa..9570f0f10f 100644 --- a/internal/benthos/benthos-builder/builders/generate-ai.go +++ b/internal/benthos/benthos-builder/builders/generate-ai.go @@ -19,10 +19,9 @@ import ( ) type generateAIBuilder struct { - transformerclient mgmtv1alpha1connect.TransformersServiceClient - sqlmanagerclient sqlmanager.SqlManagerClient - connectionclient mgmtv1alpha1connect.ConnectionServiceClient - aiGroupedTableCols map[string][]string + transformerclient mgmtv1alpha1connect.TransformersServiceClient + sqlmanagerclient sqlmanager.SqlManagerClient + connectionclient mgmtv1alpha1connect.ConnectionServiceClient } func NewGenerateAIBuilder( @@ -32,10 +31,9 @@ func NewGenerateAIBuilder( driver string, ) bb_internal.BenthosBuilder { return &generateAIBuilder{ - transformerclient: transformerclient, - sqlmanagerclient: sqlmanagerclient, - connectionclient: connectionclient, - aiGroupedTableCols: map[string][]string{}, + transformerclient: transformerclient, + sqlmanagerclient: sqlmanagerclient, + connectionclient: connectionclient, } } @@ -123,16 +121,6 @@ func (b *generateAIBuilder) BuildSourceConfigs(ctx context.Context, params *bb_i userBatchSize, ) - // builds a map of table key to columns for AI Generated schemas as they are calculated lazily instead of via job mappings - aiGroupedTableCols := map[string][]string{} - for _, agm := range mappings { - 
key := neosync_benthos.BuildBenthosTable(agm.Schema, agm.Table) - for _, col := range agm.Columns { - aiGroupedTableCols[key] = append(aiGroupedTableCols[key], col.Column) - } - } - b.aiGroupedTableCols = aiGroupedTableCols - return sourceResponses, nil } @@ -217,12 +205,6 @@ func (b *generateAIBuilder) BuildDestinationConfig(ctx context.Context, params * if err != nil { return nil, fmt.Errorf("unable to parse destination options: %w", err) } - tableKey := neosync_benthos.BuildBenthosTable(benthosConfig.TableSchema, benthosConfig.TableName) - - cols, ok := b.aiGroupedTableCols[tableKey] - if !ok { - return nil, fmt.Errorf("unable to find table columns for key (%s) when building destination connection", tableKey) - } processorConfigs := []neosync_benthos.ProcessorConfig{} for _, pc := range benthosConfig.Processors { @@ -244,12 +226,9 @@ func (b *generateAIBuilder) BuildDestinationConfig(ctx context.Context, params * ConnectionId: params.DestConnection.GetId(), Schema: benthosConfig.TableSchema, Table: benthosConfig.TableName, - Columns: cols, OnConflictDoNothing: destOpts.OnConflictDoNothing, TruncateOnRetry: destOpts.Truncate, - ArgsMapping: buildPlainInsertArgs(cols), - Batching: &neosync_benthos.Batching{ Period: destOpts.BatchPeriod, Count: destOpts.BatchCount, diff --git a/internal/benthos/benthos-builder/builders/generate.go b/internal/benthos/benthos-builder/builders/generate.go index 1d6d377199..b9c6ea1a3e 100644 --- a/internal/benthos/benthos-builder/builders/generate.go +++ b/internal/benthos/benthos-builder/builders/generate.go @@ -20,6 +20,7 @@ type generateBuilder struct { transformerclient mgmtv1alpha1connect.TransformersServiceClient sqlmanagerclient sqlmanager.SqlManagerClient connectionclient mgmtv1alpha1connect.ConnectionServiceClient + driver string } func NewGenerateBuilder( @@ -54,6 +55,7 @@ func (b *generateBuilder) BuildSourceConfigs(ctx context.Context, params *bb_int return nil, fmt.Errorf("unable to create new sql db: %w", err) } defer db.Db().Close() + b.driver = db.Driver() groupedMappings := groupMappingsByTable(job.Mappings) groupedTableMapping := getTableMappingsMap(groupedMappings) @@ -179,6 +181,11 @@ func (b *generateBuilder) BuildDestinationConfig(ctx context.Context, params *bb processorConfigs = append(processorConfigs, *pc) } + sqlProcessor, err := getSqlBatchProcessors(b.driver, benthosConfig.Columns, map[string]string{}, benthosConfig.ColumnDefaultProperties) + if err != nil { + return nil, err + } + config.BenthosDsns = append(config.BenthosDsns, &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.Id}) config.Outputs = append(config.Outputs, neosync_benthos.Outputs{ Fallback: []neosync_benthos.Outputs{ @@ -193,18 +200,15 @@ func (b *generateBuilder) BuildDestinationConfig(ctx context.Context, params *bb PooledSqlInsert: &neosync_benthos.PooledSqlInsert{ ConnectionId: params.DestConnection.GetId(), - Schema: benthosConfig.TableSchema, - Table: benthosConfig.TableName, - Columns: benthosConfig.Columns, - ColumnDefaultProperties: benthosConfig.ColumnDefaultProperties, - OnConflictDoNothing: destOpts.OnConflictDoNothing, - TruncateOnRetry: destOpts.Truncate, - - ArgsMapping: buildPlainInsertArgs(benthosConfig.Columns), + Schema: benthosConfig.TableSchema, + Table: benthosConfig.TableName, + OnConflictDoNothing: destOpts.OnConflictDoNothing, + TruncateOnRetry: destOpts.Truncate, Batching: &neosync_benthos.Batching{ - Period: destOpts.BatchPeriod, - Count: destOpts.BatchCount, + Period: destOpts.BatchPeriod, + Count: destOpts.BatchCount, + 
Processors: []*neosync_benthos.BatchProcessor{sqlProcessor}, }, MaxInFlight: int(destOpts.MaxInFlight), }, diff --git a/internal/benthos/benthos-builder/builders/sql-util.go b/internal/benthos/benthos-builder/builders/sql-util.go index c973e182c2..d09654fb92 100644 --- a/internal/benthos/benthos-builder/builders/sql-util.go +++ b/internal/benthos/benthos-builder/builders/sql-util.go @@ -142,17 +142,6 @@ func getMapValuesCount[K comparable, V any](m map[K][]V) int { return count } -func buildPlainInsertArgs(cols []string) string { - if len(cols) == 0 { - return "" - } - pieces := make([]string, len(cols)) - for idx := range cols { - pieces[idx] = fmt.Sprintf("this.%q", cols[idx]) - } - return fmt.Sprintf("root = [%s]", strings.Join(pieces, ", ")) -} - func buildPlainColumns(mappings []*mgmtv1alpha1.JobMapping) []string { columns := make([]string, len(mappings)) for idx := range mappings { @@ -439,6 +428,7 @@ func getColumnDefaultProperties( if !ok { return nil, fmt.Errorf("transformer missing for column: %s", cName) } + var hasDefaultTransformer bool if jmTransformer != nil && isDefaultJobMappingTransformer(jmTransformer) { hasDefaultTransformer = true @@ -906,3 +896,25 @@ func cleanPostgresType(dataType string) string { } return strings.TrimSpace(dataType[:parenIndex]) } + +func shouldOverrideColumnDefault(columnDefaults map[string]*neosync_benthos.ColumnDefaultProperties) bool { + for _, cd := range columnDefaults { + if cd != nil && !cd.HasDefaultTransformer && cd.NeedsOverride { + return true + } + } + return false +} + +func getSqlBatchProcessors(driver string, columns []string, columnDataTypes map[string]string, columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties) (*neosync_benthos.BatchProcessor, error) { + switch driver { + case sqlmanager_shared.PostgresDriver: + return &neosync_benthos.BatchProcessor{NeosyncToPgx: &neosync_benthos.NeosyncToPgxConfig{Columns: columns, ColumnDataTypes: columnDataTypes, ColumnDefaultProperties: columnDefaultProperties}}, nil + case sqlmanager_shared.MysqlDriver: + return &neosync_benthos.BatchProcessor{NeosyncToMysql: &neosync_benthos.NeosyncToMysqlConfig{Columns: columns, ColumnDataTypes: columnDataTypes, ColumnDefaultProperties: columnDefaultProperties}}, nil + case sqlmanager_shared.MssqlDriver: + return &neosync_benthos.BatchProcessor{NeosyncToMssql: &neosync_benthos.NeosyncToMssqlConfig{Columns: columns, ColumnDataTypes: columnDataTypes, ColumnDefaultProperties: columnDefaultProperties}}, nil + default: + return nil, fmt.Errorf("unsupported driver %q when attempting to get sql batch processors", driver) + } +} diff --git a/internal/benthos/benthos-builder/builders/sql.go b/internal/benthos/benthos-builder/builders/sql.go index ebfc645748..7524542c0e 100644 --- a/internal/benthos/benthos-builder/builders/sql.go +++ b/internal/benthos/benthos-builder/builders/sql.go @@ -28,7 +28,6 @@ type sqlSyncBuilder struct { redisConfig *shared.RedisConfig driver string selectQueryBuilder bb_shared.SelectQueryMapBuilder - options *SqlSyncOptions // reverse of table dependency // map of foreign key to source table + column @@ -40,37 +39,19 @@ type sqlSyncBuilder struct { isNotForeignKeySafeSubsetMap map[string]map[tabledependency.RunType]bool // schema.table -> true if the query could return rows that violate foreign key constraints } -type SqlSyncOption func(*SqlSyncOptions) -type SqlSyncOptions struct { - rawInsertMode bool -} - -// WithRawInsertMode inserts data as is -func WithRawInsertMode() SqlSyncOption { - return func(opts 
*SqlSyncOptions) { - opts.rawInsertMode = true - } -} - func NewSqlSyncBuilder( transformerclient mgmtv1alpha1connect.TransformersServiceClient, sqlmanagerclient sqlmanager.SqlManagerClient, redisConfig *shared.RedisConfig, databaseDriver string, selectQueryBuilder bb_shared.SelectQueryMapBuilder, - opts ...SqlSyncOption, ) bb_internal.BenthosBuilder { - options := &SqlSyncOptions{} - for _, opt := range opts { - opt(options) - } return &sqlSyncBuilder{ transformerclient: transformerclient, sqlmanagerclient: sqlmanagerclient, redisConfig: redisConfig, driver: databaseDriver, selectQueryBuilder: selectQueryBuilder, - options: options, isNotForeignKeySafeSubsetMap: map[string]map[tabledependency.RunType]bool{}, } } @@ -340,8 +321,6 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ config.BenthosDsns = append(config.BenthosDsns, &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.Id}) if benthosConfig.RunType == tabledependency.RunTypeUpdate { - args := benthosConfig.Columns - args = append(args, benthosConfig.PrimaryKeys...) config.Outputs = append(config.Outputs, neosync_benthos.Outputs{ Fallback: []neosync_benthos.Outputs{ { @@ -354,7 +333,6 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ SkipForeignKeyViolations: skipForeignKeyViolations, MaxInFlight: int(destOpts.MaxInFlight), WhereColumns: benthosConfig.PrimaryKeys, - ArgsMapping: buildPlainInsertArgs(args), Batching: &neosync_benthos.Batching{ Period: destOpts.BatchPeriod, @@ -402,24 +380,17 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ } } - columnTypes := []string{} // use map going forward columnDataTypes := map[string]string{} for _, c := range benthosConfig.Columns { colType, ok := colInfoMap[c] if ok { columnDataTypes[c] = colType.DataType - columnTypes = append(columnTypes, colType.DataType) - } else { - columnTypes = append(columnTypes, "") } } - batchProcessors := []*neosync_benthos.BatchProcessor{} - if benthosConfig.Config.Input.Inputs.NeosyncConnectionData != nil { - batchProcessors = append(batchProcessors, &neosync_benthos.BatchProcessor{JsonToSql: &neosync_benthos.JsonToSqlConfig{ColumnDataTypes: columnDataTypes}}) - } - if b.driver == sqlmanager_shared.PostgresDriver || strings.EqualFold(b.driver, "postgres") { - batchProcessors = append(batchProcessors, &neosync_benthos.BatchProcessor{NeosyncToPgx: &neosync_benthos.NeosyncToPgxConfig{}}) + sqlProcessor, err := getSqlBatchProcessors(b.driver, benthosConfig.Columns, columnDataTypes, columnDefaultProperties) + if err != nil { + return nil, err } prefix, suffix := getInsertPrefixAndSuffix(b.driver, benthosConfig.TableSchema, benthosConfig.TableName, columnDefaultProperties) @@ -429,23 +400,19 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ PooledSqlInsert: &neosync_benthos.PooledSqlInsert{ ConnectionId: params.DestConnection.GetId(), - Schema: benthosConfig.TableSchema, - Table: benthosConfig.TableName, - Columns: benthosConfig.Columns, - ColumnsDataTypes: columnTypes, - ColumnDefaultProperties: columnDefaultProperties, - OnConflictDoNothing: destOpts.OnConflictDoNothing, - SkipForeignKeyViolations: skipForeignKeyViolations, - RawInsertMode: b.options.rawInsertMode, - TruncateOnRetry: destOpts.Truncate, - ArgsMapping: buildPlainInsertArgs(benthosConfig.Columns), - Prefix: prefix, - Suffix: suffix, + Schema: benthosConfig.TableSchema, + Table: benthosConfig.TableName, + OnConflictDoNothing: destOpts.OnConflictDoNothing, + 
SkipForeignKeyViolations: skipForeignKeyViolations, + ShouldOverrideColumnDefault: shouldOverrideColumnDefault(columnDefaultProperties), + TruncateOnRetry: destOpts.Truncate, + Prefix: prefix, + Suffix: suffix, Batching: &neosync_benthos.Batching{ Period: destOpts.BatchPeriod, Count: destOpts.BatchCount, - Processors: batchProcessors, + Processors: []*neosync_benthos.BatchProcessor{sqlProcessor}, }, MaxInFlight: int(destOpts.MaxInFlight), }, diff --git a/worker/pkg/benthos/config.go b/worker/pkg/benthos/config.go index b8061c56fd..e5dcfc6b73 100644 --- a/worker/pkg/benthos/config.go +++ b/worker/pkg/benthos/config.go @@ -362,7 +362,6 @@ type PooledSqlUpdate struct { Columns []string `json:"columns" yaml:"columns"` WhereColumns []string `json:"where_columns" yaml:"where_columns"` SkipForeignKeyViolations bool `json:"skip_foreign_key_violations" yaml:"skip_foreign_key_violations"` - ArgsMapping string `json:"args_mapping" yaml:"args_mapping"` Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` MaxRetryAttempts *uint `json:"max_retry_attempts,omitempty" yaml:"max_retry_attempts,omitempty"` RetryAttemptDelay *string `json:"retry_attempt_delay,omitempty" yaml:"retry_attempt_delay,omitempty"` @@ -376,23 +375,19 @@ type ColumnDefaultProperties struct { } type PooledSqlInsert struct { - ConnectionId string `json:"connection_id" yaml:"connection_id"` - Schema string `json:"schema" yaml:"schema"` - Table string `json:"table" yaml:"table"` - Columns []string `json:"columns" yaml:"columns"` - ColumnsDataTypes []string `json:"column_data_types" yaml:"column_data_types"` - ColumnDefaultProperties map[string]*ColumnDefaultProperties `json:"column_default_properties" yaml:"column_default_properties"` - OnConflictDoNothing bool `json:"on_conflict_do_nothing" yaml:"on_conflict_do_nothing"` - TruncateOnRetry bool `json:"truncate_on_retry" yaml:"truncate_on_retry"` - SkipForeignKeyViolations bool `json:"skip_foreign_key_violations" yaml:"skip_foreign_key_violations"` - RawInsertMode bool `json:"raw_insert_mode" yaml:"raw_insert_mode"` - ArgsMapping string `json:"args_mapping" yaml:"args_mapping"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` - Prefix *string `json:"prefix,omitempty" yaml:"prefix,omitempty"` - Suffix *string `json:"suffix,omitempty" yaml:"suffix,omitempty"` - MaxRetryAttempts *uint `json:"max_retry_attempts,omitempty" yaml:"max_retry_attempts,omitempty"` - RetryAttemptDelay *string `json:"retry_attempt_delay,omitempty" yaml:"retry_attempt_delay,omitempty"` - MaxInFlight int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` + ConnectionId string `json:"connection_id" yaml:"connection_id"` + Schema string `json:"schema" yaml:"schema"` + Table string `json:"table" yaml:"table"` + OnConflictDoNothing bool `json:"on_conflict_do_nothing" yaml:"on_conflict_do_nothing"` + TruncateOnRetry bool `json:"truncate_on_retry" yaml:"truncate_on_retry"` + SkipForeignKeyViolations bool `json:"skip_foreign_key_violations" yaml:"skip_foreign_key_violations"` + ShouldOverrideColumnDefault bool `json:"should_override_column_default" yaml:"should_override_column_default"` + Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` + Prefix *string `json:"prefix,omitempty" yaml:"prefix,omitempty"` + Suffix *string `json:"suffix,omitempty" yaml:"suffix,omitempty"` + MaxRetryAttempts *uint `json:"max_retry_attempts,omitempty" yaml:"max_retry_attempts,omitempty"` + RetryAttemptDelay *string `json:"retry_attempt_delay,omitempty" 
yaml:"retry_attempt_delay,omitempty"` + MaxInFlight int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` } type SqlInsert struct { @@ -456,21 +451,33 @@ type Batching struct { } type BatchProcessor struct { - Archive *ArchiveProcessor `json:"archive,omitempty" yaml:"archive,omitempty"` - Compress *CompressProcessor `json:"compress,omitempty" yaml:"compress,omitempty"` - SqlToJson *SqlToJsonConfig `json:"sql_to_json,omitempty" yaml:"sql_to_json,omitempty"` - JsonToSql *JsonToSqlConfig `json:"json_to_sql,omitempty" yaml:"json_to_sql,omitempty"` - NeosyncToPgx *NeosyncToPgxConfig `json:"neosync_to_pgx,omitempty" yaml:"neosync_to_pgx,omitempty"` + Archive *ArchiveProcessor `json:"archive,omitempty" yaml:"archive,omitempty"` + Compress *CompressProcessor `json:"compress,omitempty" yaml:"compress,omitempty"` + NeosyncToJson *NeosyncToJsonConfig `json:"neosync_to_json,omitempty" yaml:"neosync_to_json,omitempty"` + NeosyncToPgx *NeosyncToPgxConfig `json:"neosync_to_pgx,omitempty" yaml:"neosync_to_pgx,omitempty"` + NeosyncToMysql *NeosyncToMysqlConfig `json:"neosync_to_mysql,omitempty" yaml:"neosync_to_mysql,omitempty"` + NeosyncToMssql *NeosyncToMssqlConfig `json:"neosync_to_mssql,omitempty" yaml:"neosync_to_mssql,omitempty"` } type NeosyncToPgxConfig struct { + Columns []string `json:"columns" yaml:"columns"` + ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` + ColumnDefaultProperties map[string]*ColumnDefaultProperties `json:"column_default_properties" yaml:"column_default_properties"` } -type JsonToSqlConfig struct { - ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` +type NeosyncToMysqlConfig struct { + Columns []string `json:"columns" yaml:"columns"` + ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` + ColumnDefaultProperties map[string]*ColumnDefaultProperties `json:"column_default_properties" yaml:"column_default_properties"` } -type SqlToJsonConfig struct{} +type NeosyncToMssqlConfig struct { + Columns []string `json:"columns" yaml:"columns"` + ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` + ColumnDefaultProperties map[string]*ColumnDefaultProperties `json:"column_default_properties" yaml:"column_default_properties"` +} + +type NeosyncToJsonConfig struct{} type ArchiveProcessor struct { Format string `json:"format" yaml:"format"` diff --git a/worker/pkg/benthos/environment/environment.go b/worker/pkg/benthos/environment/environment.go index cb2d068d8c..a451300262 100644 --- a/worker/pkg/benthos/environment/environment.go +++ b/worker/pkg/benthos/environment/environment.go @@ -169,19 +169,24 @@ func NewWithEnvironment(env *service.Environment, logger *slog.Logger, opts ...O return nil, fmt.Errorf("unable to register default mapping processor to benthos instance: %w", err) } - err = neosync_benthos_json.RegisterSqlToJsonProcessor(env) + err = neosync_benthos_json.RegisterNeosyncToJsonProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register SQL to JSON processor to benthos instance: %w", err) + return nil, fmt.Errorf("unable to register Neosync to JSON processor to benthos instance: %w", err) } - err = neosync_benthos_sql.RegisterJsonToSqlProcessor(env) + err = neosync_benthos_sql.RegisterNeosyncToPgxProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register JSON to SQL processor to benthos instance: %w", err) + return nil, fmt.Errorf("unable to register Neosync to PGX processor to benthos 
instance: %w", err) } - err = neosync_benthos_sql.RegisterNeosyncToPgxProcessor(env) + err = neosync_benthos_sql.RegisterNeosyncToMysqlProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register Neosync to PGX processor to benthos instance: %w", err) + return nil, fmt.Errorf("unable to register Neosync to MYSQL processor to benthos instance: %w", err) + } + + err = neosync_benthos_sql.RegisterNeosyncToMssqlProcessor(env) + if err != nil { + return nil, fmt.Errorf("unable to register Neosync to MSSQL processor to benthos instance: %w", err) } if config.blobEnv != nil { diff --git a/worker/pkg/benthos/json/sql_processor.go b/worker/pkg/benthos/json/processor_neosync_json.go similarity index 67% rename from worker/pkg/benthos/json/sql_processor.go rename to worker/pkg/benthos/json/processor_neosync_json.go index 4115210652..b1bc0a4682 100644 --- a/worker/pkg/benthos/json/sql_processor.go +++ b/worker/pkg/benthos/json/processor_neosync_json.go @@ -8,31 +8,31 @@ import ( "github.com/warpstreamlabs/bento/public/service" ) -func sqlToJsonProcessorConfig() *service.ConfigSpec { +func neosyncToJsonProcessorConfig() *service.ConfigSpec { return service.NewConfigSpec() } -func RegisterSqlToJsonProcessor(env *service.Environment) error { +func RegisterNeosyncToJsonProcessor(env *service.Environment) error { return env.RegisterBatchProcessor( - "sql_to_json", - sqlToJsonProcessorConfig(), + "neosync_to_json", + neosyncToJsonProcessorConfig(), func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchProcessor, error) { - proc := newMysqlToJsonProcessor(conf, mgr) + proc := newNeosyncToJsonProcessor(conf, mgr) return proc, nil }) } -type sqlToJsonProcessor struct { +type neosyncToJsonProcessor struct { logger *service.Logger } -func newMysqlToJsonProcessor(_ *service.ParsedConfig, mgr *service.Resources) *sqlToJsonProcessor { - return &sqlToJsonProcessor{ +func newNeosyncToJsonProcessor(_ *service.ParsedConfig, mgr *service.Resources) *neosyncToJsonProcessor { + return &neosyncToJsonProcessor{ logger: mgr.Logger(), } } -func (m *sqlToJsonProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { +func (m *neosyncToJsonProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { newBatch := make(service.MessageBatch, 0, len(batch)) for _, msg := range batch { root, err := msg.AsStructuredMut() @@ -51,7 +51,7 @@ func (m *sqlToJsonProcessor) ProcessBatch(ctx context.Context, batch service.Mes return []service.MessageBatch{newBatch}, nil } -func (m *sqlToJsonProcessor) Close(context.Context) error { +func (m *neosyncToJsonProcessor) Close(context.Context) error { return nil } @@ -74,6 +74,7 @@ func transform(root any) any { return v.Format(time.DateTime) case []uint8: return string(v) + // TODO this should be neosync bit type case *sqlscanners.BitString: return v.String() default: diff --git a/worker/pkg/benthos/sql/json_processor.go b/worker/pkg/benthos/sql/json_processor.go deleted file mode 100644 index a05f32fda3..0000000000 --- a/worker/pkg/benthos/sql/json_processor.go +++ /dev/null @@ -1,177 +0,0 @@ -package neosync_benthos_sql - -import ( - "context" - "encoding/binary" - "encoding/json" - "strconv" - - "github.com/lib/pq" - pgutil "github.com/nucleuscloud/neosync/internal/postgres" - "github.com/warpstreamlabs/bento/public/service" -) - -func jsonToSqlProcessorConfig() *service.ConfigSpec { - return service.NewConfigSpec().Field(service.NewStringMapField("column_data_types")) 
-} - -func RegisterJsonToSqlProcessor(env *service.Environment) error { - return env.RegisterBatchProcessor( - "json_to_sql", - jsonToSqlProcessorConfig(), - func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchProcessor, error) { - proc, err := newJsonToSqlProcessor(conf, mgr) - if err != nil { - return nil, err - } - return proc, nil - }) -} - -type jsonToSqlProcessor struct { - logger *service.Logger - columnDataTypes map[string]string // column name to datatype -} - -func newJsonToSqlProcessor(conf *service.ParsedConfig, mgr *service.Resources) (*jsonToSqlProcessor, error) { - columnDataTypes, err := conf.FieldStringMap("column_data_types") - if err != nil { - return nil, err - } - return &jsonToSqlProcessor{ - logger: mgr.Logger(), - columnDataTypes: columnDataTypes, - }, nil -} - -func (p *jsonToSqlProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { - newBatch := make(service.MessageBatch, 0, len(batch)) - for _, msg := range batch { - root, err := msg.AsStructuredMut() - if err != nil { - return nil, err - } - newRoot := p.transform("", root) - newMsg := msg.Copy() - newMsg.SetStructured(newRoot) - newBatch = append(newBatch, newMsg) - } - - if len(newBatch) == 0 { - return nil, nil - } - return []service.MessageBatch{newBatch}, nil -} - -func (m *jsonToSqlProcessor) Close(context.Context) error { - return nil -} - -func (p *jsonToSqlProcessor) transform(path string, root any) any { - switch v := root.(type) { - case map[string]any: - newMap := make(map[string]any) - for k, v2 := range v { - newValue := p.transform(k, v2) - newMap[k] = newValue - } - return newMap - case nil: - return v - case []byte: - datatype, ok := p.columnDataTypes[path] - if !ok { - return v - } - // TODO move to pgx processor - if pgutil.IsPgArrayColumnDataType(datatype) { - pgarray, err := processPgArray(v, datatype) - if err != nil { - p.logger.Errorf("unable to process PG Array: %w", err) - return v - } - return pgarray - } - switch datatype { - case "bit": - bit, err := convertStringToBit(string(v)) - if err != nil { - p.logger.Errorf("unable to convert bit string to SQL bit []byte: %w", err) - return v - } - return bit - case "json", "jsonb": - validJson, err := getValidJson(v) - if err != nil { - p.logger.Errorf("unable to get valid json: %w", err) - return v - } - return validJson - case "money", "uuid", "time with time zone", "timestamp with time zone": - // Convert UUID []byte to string before inserting since postgres driver stores uuid bytes in different order - return string(v) - } - return v - default: - return v - } -} - -func processPgArray(bits []byte, datatype string) (any, error) { - var pgarray []any - err := json.Unmarshal(bits, &pgarray) - if err != nil { - return nil, err - } - switch datatype { - case "json[]", "jsonb[]": - jsonArray, err := stringifyJsonArray(pgarray) - if err != nil { - return nil, err - } - return pq.Array(jsonArray), nil - default: - return pq.Array(pgarray), nil - } -} - -// handles case where json strings are not quoted -func getValidJson(jsonData []byte) ([]byte, error) { - isValidJson := json.Valid(jsonData) - if isValidJson { - return jsonData, nil - } - - quotedData, err := json.Marshal(string(jsonData)) - if err != nil { - return nil, err - } - return quotedData, nil -} - -func stringifyJsonArray(pgarray []any) ([]string, error) { - jsonArray := make([]string, len(pgarray)) - for i, item := range pgarray { - bytes, err := json.Marshal(item) - if err != nil { - return nil, err - } - 
jsonArray[i] = string(bytes) - } - return jsonArray, nil -} - -func convertStringToBit(bitString string) ([]byte, error) { - val, err := strconv.ParseUint(bitString, 2, len(bitString)) - if err != nil { - return nil, err - } - - // Always allocate 8 bytes for PutUint64 - bytes := make([]byte, 8) - binary.BigEndian.PutUint64(bytes, val) - - // Calculate actual needed bytes and return only those - neededBytes := (len(bitString) + 7) / 8 - return bytes[len(bytes)-neededBytes:], nil -} diff --git a/worker/pkg/benthos/sql/json_processor_test.go b/worker/pkg/benthos/sql/json_processor_test.go deleted file mode 100644 index 7de52328fc..0000000000 --- a/worker/pkg/benthos/sql/json_processor_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package neosync_benthos_sql - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func Test_convertStringToBit(t *testing.T) { - t.Run("8 bits", func(t *testing.T) { - got, err := convertStringToBit("10101010") - require.NoError(t, err) - expected := []byte{170} - require.Equalf(t, expected, got, "got %v, want %v", got, expected) - }) - - t.Run("1 bit", func(t *testing.T) { - got, err := convertStringToBit("1") - require.NoError(t, err) - expected := []byte{1} - require.Equalf(t, expected, got, "got %v, want %v", got, expected) - }) - - t.Run("16 bits", func(t *testing.T) { - got, err := convertStringToBit("1010101010101010") - require.NoError(t, err) - expected := []byte{170, 170} - require.Equalf(t, expected, got, "got %v, want %v", got, expected) - }) - - t.Run("24 bits", func(t *testing.T) { - got, err := convertStringToBit("101010101111111100000000") - require.NoError(t, err) - expected := []byte{170, 255, 0} - require.Equalf(t, expected, got, "got %v, want %v", got, expected) - }) - - t.Run("invalid binary string", func(t *testing.T) { - _, err := convertStringToBit("102") - require.Error(t, err) - }) - - t.Run("empty string", func(t *testing.T) { - _, err := convertStringToBit("") - require.Error(t, err) - }) -} diff --git a/worker/pkg/benthos/sql/output_sql_insert.go b/worker/pkg/benthos/sql/output_sql_insert.go index 68e806f771..0fa36ffdba 100644 --- a/worker/pkg/benthos/sql/output_sql_insert.go +++ b/worker/pkg/benthos/sql/output_sql_insert.go @@ -2,7 +2,6 @@ package neosync_benthos_sql import ( "context" - "encoding/json" "fmt" "log/slog" "sync" @@ -14,7 +13,6 @@ import ( mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" querybuilder "github.com/nucleuscloud/neosync/worker/pkg/query-builder" - "github.com/warpstreamlabs/bento/public/bloblang" "github.com/warpstreamlabs/bento/public/service" ) @@ -23,14 +21,10 @@ func sqlInsertOutputSpec() *service.ConfigSpec { Field(service.NewStringField("connection_id")). Field(service.NewStringField("schema")). Field(service.NewStringField("table")). - Field(service.NewStringListField("columns")). - Field(service.NewStringListField("column_data_types")). - Field(service.NewAnyMapField("column_default_properties")). - Field(service.NewBloblangField("args_mapping").Optional()). Field(service.NewBoolField("on_conflict_do_nothing").Optional().Default(false)). Field(service.NewBoolField("skip_foreign_key_violations").Optional().Default(false)). - Field(service.NewBoolField("raw_insert_mode").Optional().Default(false)). Field(service.NewBoolField("truncate_on_retry").Optional().Default(false)). + Field(service.NewBoolField("should_override_column_default").Optional().Default(false)). 
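Between the struct change in config.go above and the spec fields here, `pooled_sql_insert` is left with connection/table coordinates and behavior flags only; column metadata and `args_mapping` now travel with the batch processors instead. A sketch of the resulting output config, with hypothetical values:

```go
// Slimmed-down insert output: no columns, column_data_types,
// column_default_properties, args_mapping, or raw_insert_mode.
// sqlProcessor comes from getSqlBatchProcessors as sketched earlier.
out := neosync_benthos.PooledSqlInsert{
	ConnectionId:                "conn-123", // hypothetical
	Schema:                      "public",
	Table:                       "users",
	OnConflictDoNothing:         true,
	SkipForeignKeyViolations:    true,
	ShouldOverrideColumnDefault: true, // derived via shouldOverrideColumnDefault(...)
	TruncateOnRetry:             false,
	Batching: &neosync_benthos.Batching{
		Count:      100,
		Period:     "5s",
		Processors: []*neosync_benthos.BatchProcessor{sqlProcessor},
	},
	MaxInFlight: 64,
}
```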
Field(service.NewIntField("max_in_flight").Default(64)). Field(service.NewBatchPolicyField("batching")). Field(service.NewStringField("prefix").Optional()). @@ -73,21 +67,18 @@ type pooledInsertOutput struct { logger *service.Logger slogger *slog.Logger - schema string - table string - columns []string - columnDataTypes []string - columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties - onConflictDoNothing bool - skipForeignKeyViolations bool - rawInsertMode bool - truncateOnRetry bool - prefix *string - suffix *string - - argsMapping *bloblang.Executor - shutSig *shutdown.Signaller - isRetry bool + schema string + table string + columns []string + onConflictDoNothing bool + skipForeignKeyViolations bool + truncateOnRetry bool + shouldOverrideColumnDefault bool + prefix *string + suffix *string + + shutSig *shutdown.Signaller + isRetry bool maxRetryAttempts uint retryDelay time.Duration @@ -109,40 +100,6 @@ func newInsertOutput(conf *service.ParsedConfig, mgr *service.Resources, provide return nil, err } - columns, err := conf.FieldStringList("columns") - if err != nil { - return nil, err - } - - columnDataTypes, err := conf.FieldStringList("column_data_types") - if err != nil { - return nil, err - } - - columnDefaultPropertiesConfig, err := conf.FieldAnyMap("column_default_properties") - if err != nil { - return nil, err - } - - columnDefaultProperties := map[string]*neosync_benthos.ColumnDefaultProperties{} - for key, properties := range columnDefaultPropertiesConfig { - props, err := properties.FieldAny() - if err != nil { - return nil, err - } - jsonData, err := json.Marshal(props) - if err != nil { - return nil, fmt.Errorf("failed to marshal properties for key %s: %w", key, err) - } - - var colDefaults neosync_benthos.ColumnDefaultProperties - if err := json.Unmarshal(jsonData, &colDefaults); err != nil { - return nil, fmt.Errorf("failed to unmarshal properties for key %s: %w", key, err) - } - - columnDefaultProperties[key] = &colDefaults - } - onConflictDoNothing, err := conf.FieldBool("on_conflict_do_nothing") if err != nil { return nil, err @@ -153,12 +110,12 @@ func newInsertOutput(conf *service.ParsedConfig, mgr *service.Resources, provide return nil, err } - rawInsertMode, err := conf.FieldBool("raw_insert_mode") + truncateOnRetry, err := conf.FieldBool("truncate_on_retry") if err != nil { return nil, err } - truncateOnRetry, err := conf.FieldBool("truncate_on_retry") + shouldOverrideColumnDefault, err := conf.FieldBool("should_override_column_default") if err != nil { return nil, err } @@ -181,13 +138,6 @@ func newInsertOutput(conf *service.ParsedConfig, mgr *service.Resources, provide suffix = &suffixStr } - var argsMapping *bloblang.Executor - if conf.Contains("args_mapping") { - if argsMapping, err = conf.FieldBloblang("args_mapping"); err != nil { - return nil, err - } - } - retryAttemptsConf, err := conf.FieldInt("max_retry_attempts") if err != nil { return nil, err @@ -210,27 +160,23 @@ func newInsertOutput(conf *service.ParsedConfig, mgr *service.Resources, provide } output := &pooledInsertOutput{ - connectionId: connectionId, - driver: driver, - logger: mgr.Logger(), - slogger: logger, - shutSig: shutdown.NewSignaller(), - argsMapping: argsMapping, - provider: provider, - schema: schema, - table: table, - columns: columns, - columnDataTypes: columnDataTypes, - columnDefaultProperties: columnDefaultProperties, - onConflictDoNothing: onConflictDoNothing, - skipForeignKeyViolations: skipForeignKeyViolations, - rawInsertMode: rawInsertMode, - 
truncateOnRetry: truncateOnRetry, - prefix: prefix, - suffix: suffix, - isRetry: isRetry, - maxRetryAttempts: retryAttempts, - retryDelay: retryDelay, + connectionId: connectionId, + driver: driver, + logger: mgr.Logger(), + slogger: logger, + shutSig: shutdown.NewSignaller(), + provider: provider, + schema: schema, + table: table, + onConflictDoNothing: onConflictDoNothing, + skipForeignKeyViolations: skipForeignKeyViolations, + truncateOnRetry: truncateOnRetry, + shouldOverrideColumnDefault: shouldOverrideColumnDefault, + prefix: prefix, + suffix: suffix, + isRetry: isRetry, + maxRetryAttempts: retryAttempts, + retryDelay: retryDelay, } return output, nil } @@ -283,47 +229,27 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa return nil } - var executor *service.MessageBatchBloblangExecutor - if s.argsMapping != nil { - executor = batch.BloblangExecutor(s.argsMapping) + columnMap := map[string]struct{}{} + for _, col := range s.columns { + columnMap[col] = struct{}{} } - rows := [][]any{} - for i := range batch { - if s.argsMapping == nil { - continue - } - resMsg, err := executor.Query(i) - if err != nil { - return err - } - - iargs, err := resMsg.AsStructured() - if err != nil { - return err - } - - args, ok := iargs.([]any) + rows := []map[string]any{} + for _, msg := range batch { + m, _ := msg.AsStructured() + msgMap, ok := m.(map[string]any) if !ok { - return fmt.Errorf("mapping returned non-array result: %T", iargs) + return fmt.Errorf("message returned non-map result: %T", msgMap) } - rows = append(rows, args) + rows = append(rows, msgMap) } - - // keep same index and order of columns slice - columnDefaults := make([]*neosync_benthos.ColumnDefaultProperties, len(s.columns)) - for idx, cName := range s.columns { - defaults, ok := s.columnDefaultProperties[cName] - if !ok { - defaults = &neosync_benthos.ColumnDefaultProperties{} - } - columnDefaults[idx] = defaults + if len(rows) == 0 { + s.logger.Debug("no rows to insert") + return nil } options := []querybuilder.InsertOption{ - querybuilder.WithColumnDataTypes(s.columnDataTypes), - querybuilder.WithColumnDefaults(columnDefaults), querybuilder.WithPrefix(s.prefix), querybuilder.WithSuffix(s.suffix), } @@ -331,15 +257,14 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa if s.onConflictDoNothing { options = append(options, querybuilder.WithOnConflictDoNothing()) } - if s.rawInsertMode { - options = append(options, querybuilder.WithRawInsertMode()) + if s.shouldOverrideColumnDefault { + options = append(options, querybuilder.WithShouldOverrideColumnDefault()) } builder, err := querybuilder.GetInsertBuilder( s.slogger, s.driver, s.schema, s.table, - s.columns, options..., ) if err != nil { @@ -357,7 +282,7 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa return err } - err = s.RetryInsertRowByRow(ctx, builder, insertQuery, s.columns, rows, columnDefaults) + err = s.RetryInsertRowByRow(ctx, builder, rows) if err != nil { return err } @@ -368,20 +293,16 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa func (s *pooledInsertOutput) RetryInsertRowByRow( ctx context.Context, builder querybuilder.InsertQueryBuilder, - insertQuery string, - columns []string, - rows [][]any, - columnDefaults []*neosync_benthos.ColumnDefaultProperties, + rows []map[string]any, ) error { fkErrorCount := 0 insertCount := 0 - preparedInsert, err := builder.BuildPreparedInsertQuerySingleRow() - if err != nil { - return err - } - 
args := builder.BuildPreparedInsertArgs(rows) - for _, row := range args { - err = s.execWithRetry(ctx, preparedInsert, row) + for _, row := range rows { + insertQuery, args, err := builder.BuildInsertQuery([]map[string]any{row}) + if err != nil { + return err + } + err = s.execWithRetry(ctx, insertQuery, args) if err != nil && neosync_benthos.IsForeignKeyViolationError(err.Error()) { fkErrorCount++ } else if err != nil && !neosync_benthos.IsForeignKeyViolationError(err.Error()) { diff --git a/worker/pkg/benthos/sql/output_sql_update.go b/worker/pkg/benthos/sql/output_sql_update.go index b5451a6966..2069711769 100644 --- a/worker/pkg/benthos/sql/output_sql_update.go +++ b/worker/pkg/benthos/sql/output_sql_update.go @@ -13,7 +13,6 @@ import ( mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" querybuilder "github.com/nucleuscloud/neosync/worker/pkg/query-builder" - "github.com/warpstreamlabs/bento/public/bloblang" "github.com/warpstreamlabs/bento/public/service" ) @@ -41,7 +40,6 @@ func sqlUpdateOutputSpec() *service.ConfigSpec { Field(service.NewStringListField("columns")). Field(service.NewStringListField("where_columns")). Field(service.NewBoolField("skip_foreign_key_violations").Optional().Default(false)). - Field(service.NewBloblangField("args_mapping").Optional()). Field(service.NewIntField("max_in_flight").Default(64)). Field(service.NewBatchPolicyField("batching")). Field(service.NewStringField("max_retry_attempts").Default(3)). @@ -87,8 +85,7 @@ type pooledUpdateOutput struct { whereCols []string skipForeignKeyViolations bool - argsMapping *bloblang.Executor - shutSig *shutdown.Signaller + shutSig *shutdown.Signaller maxRetryAttempts uint retryDelay time.Duration @@ -125,13 +122,6 @@ func newUpdateOutput(conf *service.ParsedConfig, mgr *service.Resources, provide return nil, err } - var argsMapping *bloblang.Executor - if conf.Contains("args_mapping") { - if argsMapping, err = conf.FieldBloblang("args_mapping"); err != nil { - return nil, err - } - } - retryAttemptsConf, err := conf.FieldInt("max_retry_attempts") if err != nil { return nil, err @@ -159,7 +149,6 @@ func newUpdateOutput(conf *service.ParsedConfig, mgr *service.Resources, provide connectionId: connectionId, logger: mgr.Logger(), shutSig: shutdown.NewSignaller(), - argsMapping: argsMapping, provider: provider, schema: schema, table: table, @@ -208,40 +197,16 @@ func (s *pooledUpdateOutput) WriteBatch(ctx context.Context, batch service.Messa return nil } - var executor *service.MessageBatchBloblangExecutor - if s.argsMapping != nil { - executor = batch.BloblangExecutor(s.argsMapping) - } - - for i := range batch { - if s.argsMapping == nil { - continue - } - resMsg, err := executor.Query(i) - if err != nil { - return err - } + for _, msg := range batch { + m, _ := msg.AsStructured() - iargs, err := resMsg.AsStructured() - if err != nil { - return err - } - - args, ok := iargs.([]any) + // msgMap has all the table columns and values not just the columns we are updating + msgMap, ok := m.(map[string]any) if !ok { - return fmt.Errorf("mapping returned non-array result: %T", iargs) + return fmt.Errorf("message returned non-map result: %T", msgMap) } - allCols := []string{} - allCols = append(allCols, s.columns...) - allCols = append(allCols, s.whereCols...) 
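The update output, reworked just below, follows the same pattern as the insert path: instead of evaluating an `args_mapping` Bloblang expression per message, it reads each message's structured payload directly and hands the column/value map to the query builder. A condensed sketch of the new write loop:

```go
for _, msg := range batch {
	m, _ := msg.AsStructured()
	// The structured payload carries all table columns, not just the ones
	// being updated; BuildUpdateQuery picks out s.columns and s.whereCols.
	msgMap, ok := m.(map[string]any)
	if !ok {
		// Note: the diff formats %T with msgMap, whose static type is always
		// map[string]any; m is the value that actually failed the assertion.
		return fmt.Errorf("message returned non-map result: %T", m)
	}
	query, err := querybuilder.BuildUpdateQuery(
		s.driver, s.schema, s.table, s.columns, s.whereCols, msgMap,
	)
	if err != nil {
		return err
	}
	// execute query against the pooled connection...
}
```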
- - colValMap := map[string]any{} - for idx, col := range allCols { - colValMap[col] = args[idx] - } - - query, err := querybuilder.BuildUpdateQuery(s.driver, s.schema, s.table, s.columns, s.whereCols, colValMap) + query, err := querybuilder.BuildUpdateQuery(s.driver, s.schema, s.table, s.columns, s.whereCols, msgMap) if err != nil { return err } @@ -251,6 +216,7 @@ func (s *pooledUpdateOutput) WriteBatch(ctx context.Context, batch service.Messa } } } + return nil } diff --git a/worker/pkg/benthos/sql/processor_neosync_mssql.go b/worker/pkg/benthos/sql/processor_neosync_mssql.go new file mode 100644 index 0000000000..1eed35c514 --- /dev/null +++ b/worker/pkg/benthos/sql/processor_neosync_mssql.go @@ -0,0 +1,139 @@ +package neosync_benthos_sql + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nucleuscloud/neosync/internal/gotypeutil" + neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" + "github.com/warpstreamlabs/bento/public/service" +) + +func neosyncToMssqlProcessorConfig() *service.ConfigSpec { + return service.NewConfigSpec(). + Field(service.NewStringListField("columns")). + Field(service.NewStringMapField("column_data_types")). + Field(service.NewAnyMapField("column_default_properties")) +} + +func RegisterNeosyncToMssqlProcessor(env *service.Environment) error { + return env.RegisterBatchProcessor( + "neosync_to_mssql", + neosyncToMssqlProcessorConfig(), + func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchProcessor, error) { + proc, err := newNeosyncToMssqlProcessor(conf, mgr) + if err != nil { + return nil, err + } + return proc, nil + }) +} + +type neosyncToMssqlProcessor struct { + logger *service.Logger + columns []string + columnDataTypes map[string]string + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties +} + +func newNeosyncToMssqlProcessor(conf *service.ParsedConfig, mgr *service.Resources) (*neosyncToMssqlProcessor, error) { + columns, err := conf.FieldStringList("columns") + if err != nil { + return nil, err + } + + columnDataTypes, err := conf.FieldStringMap("column_data_types") + if err != nil { + return nil, err + } + + columnDefaultPropertiesConfig, err := conf.FieldAnyMap("column_default_properties") + if err != nil { + return nil, err + } + + columnDefaultProperties, err := getColumnDefaultProperties(columnDefaultPropertiesConfig) + if err != nil { + return nil, err + } + + return &neosyncToMssqlProcessor{ + logger: mgr.Logger(), + columns: columns, + columnDataTypes: columnDataTypes, + columnDefaultProperties: columnDefaultProperties, + }, nil +} + +func (p *neosyncToMssqlProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { + newBatch := make(service.MessageBatch, 0, len(batch)) + for _, msg := range batch { + root, err := msg.AsStructuredMut() + if err != nil { + return nil, err + } + newRoot, err := transformNeosyncToMssql(p.logger, root, p.columns, p.columnDefaultProperties) + if err != nil { + return nil, err + } + newMsg := msg.Copy() + newMsg.SetStructured(newRoot) + newBatch = append(newBatch, newMsg) + } + + if len(newBatch) == 0 { + return nil, nil + } + return []service.MessageBatch{newBatch}, nil +} + +func (m *neosyncToMssqlProcessor) Close(context.Context) error { + return nil +} + +func transformNeosyncToMssql( + logger *service.Logger, + root any, + columns []string, + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties, +) (map[string]any, error) { + rootMap, ok := root.(map[string]any) + if 
!ok { + return nil, fmt.Errorf("root value must be a map[string]any") + } + + newMap := make(map[string]any) + for col, val := range rootMap { + // Skip values that aren't in the column list to handle circular references + if !isColumnInList(col, columns) { + continue + } + + colDefaults := columnDefaultProperties[col] + // sqlserver doesn't support default values. must be removed + if colDefaults != nil && colDefaults.HasDefaultTransformer { + continue + } + + newVal, err := getMssqlValue(val) + if err != nil { + logger.Warn(err.Error()) + } + newMap[col] = newVal + } + + return newMap, nil +} + +func getMssqlValue(value any) (any, error) { + if gotypeutil.IsMap(value) { + bits, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("unable to marshal JSON: %w", err) + } + return bits, nil + } + + return value, nil +} diff --git a/worker/pkg/benthos/sql/processor_neosync_mssql_test.go b/worker/pkg/benthos/sql/processor_neosync_mssql_test.go new file mode 100644 index 0000000000..13cb4a4348 --- /dev/null +++ b/worker/pkg/benthos/sql/processor_neosync_mssql_test.go @@ -0,0 +1,222 @@ +package neosync_benthos_sql + +import ( + "context" + "testing" + + neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" + "github.com/stretchr/testify/require" + "github.com/warpstreamlabs/bento/public/service" +) + +func Test_transformNeosyncToMssql(t *testing.T) { + logger := &service.Logger{} + columns := []string{"id", "name", "data", "default_col"} + columnDefaultProperties := map[string]*neosync_benthos.ColumnDefaultProperties{ + "default_col": {HasDefaultTransformer: true}, + } + + t.Run("handles basic values", func(t *testing.T) { + input := map[string]any{ + "id": 1, + "name": "test", + "data": map[string]string{"foo": "bar"}, + "default_col": "should be skipped", + } + + result, err := transformNeosyncToMssql(logger, input, columns, columnDefaultProperties) + require.NoError(t, err) + + require.Equal(t, 1, result["id"]) + require.Equal(t, "test", result["name"]) + require.Equal(t, []byte(`{"foo":"bar"}`), result["data"]) + _, exists := result["default_col"] + require.False(t, exists) + }) + + t.Run("handles nil values", func(t *testing.T) { + input := map[string]any{ + "id": nil, + "name": nil, + } + + result, err := transformNeosyncToMssql(logger, input, columns, columnDefaultProperties) + require.NoError(t, err) + + require.Nil(t, result["id"]) + require.Nil(t, result["name"]) + }) + + t.Run("skips columns not in column list", func(t *testing.T) { + input := map[string]any{ + "id": 1, + "name": "test", + "unknown_column": "should not appear", + } + + result, err := transformNeosyncToMssql(logger, input, columns, columnDefaultProperties) + require.NoError(t, err) + + require.Equal(t, 1, result["id"]) + require.Equal(t, "test", result["name"]) + _, exists := result["unknown_column"] + require.False(t, exists) + }) + + t.Run("returns error for invalid root type", func(t *testing.T) { + result, err := transformNeosyncToMssql(logger, "invalid", columns, columnDefaultProperties) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "root value must be a map[string]any") + }) +} + +func Test_getMssqlValue(t *testing.T) { + t.Run("marshals json for map value", func(t *testing.T) { + input := map[string]string{"foo": "bar"} + result, err := getMssqlValue(input) + require.NoError(t, err) + require.Equal(t, []byte(`{"foo":"bar"}`), result) + }) + + t.Run("returns original value for non-map types", func(t *testing.T) { + result, err := 
getMssqlValue("test") + require.NoError(t, err) + require.Equal(t, "test", result) + }) + + t.Run("handles nil value", func(t *testing.T) { + result, err := getMssqlValue(nil) + require.NoError(t, err) + require.Nil(t, result) + }) +} + +func Test_NeosyncToMssqlProcessor(t *testing.T) { + conf := ` +columns: + - id + - name + - age + - balance + - is_active + - created_at + - default_value +column_data_types: + id: integer + name: text + age: integer + balance: double + is_active: boolean + created_at: timestamp + default_value: text +column_default_properties: + id: + has_default_transformer: false + name: + has_default_transformer: false + default_value: + has_default_transformer: true +` + spec := neosyncToMssqlProcessorConfig() + env := service.NewEnvironment() + + procConfig, err := spec.ParseYAML(conf, env) + require.NoError(t, err) + + proc, err := newNeosyncToMssqlProcessor(procConfig, service.MockResources()) + require.NoError(t, err) + + msgMap := map[string]any{ + "id": 1, + "name": "test", + "age": 30, + "balance": 1000.50, + "is_active": true, + "created_at": "2023-01-01T00:00:00Z", + "default_value": "some default", + } + msg := service.NewMessage(nil) + msg.SetStructured(msgMap) + batch := service.MessageBatch{ + msg, + } + + results, err := proc.ProcessBatch(context.Background(), batch) + require.NoError(t, err) + require.Len(t, results, 1) + require.Len(t, results[0], 1) + + val, err := results[0][0].AsStructured() + require.NoError(t, err) + + expected := map[string]any{ + "id": msgMap["id"], + "name": msgMap["name"], + "age": msgMap["age"], + "balance": msgMap["balance"], + "is_active": msgMap["is_active"], + "created_at": msgMap["created_at"], + } + require.Equal(t, expected, val) + + require.NoError(t, proc.Close(context.Background())) +} + +func Test_NeosyncToMssqlProcessor_SubsetColumns(t *testing.T) { + conf := ` +columns: + - id + - name +column_data_types: + id: integer + name: text + age: integer + balance: double + is_active: boolean + created_at: timestamp +column_default_properties: + id: + has_default_transformer: false + name: + has_default_transformer: false +` + spec := neosyncToMssqlProcessorConfig() + env := service.NewEnvironment() + + procConfig, err := spec.ParseYAML(conf, env) + require.NoError(t, err) + + proc, err := newNeosyncToMssqlProcessor(procConfig, service.MockResources()) + require.NoError(t, err) + + msgMap := map[string]any{ + "id": 1, + "name": "test", + "age": 30, + "balance": 1000.50, + "is_active": true, + "created_at": "2023-01-01T00:00:00Z", + } + msg := service.NewMessage(nil) + msg.SetStructured(msgMap) + batch := service.MessageBatch{ + msg, + } + + results, err := proc.ProcessBatch(context.Background(), batch) + require.NoError(t, err) + require.Len(t, results, 1) + require.Len(t, results[0], 1) + + val, err := results[0][0].AsStructured() + require.NoError(t, err) + + expected := map[string]any{ + "id": msgMap["id"], + "name": msgMap["name"], + } + require.Equal(t, expected, val) + + require.NoError(t, proc.Close(context.Background())) +} diff --git a/worker/pkg/benthos/sql/processor_neosync_mysql.go b/worker/pkg/benthos/sql/processor_neosync_mysql.go new file mode 100644 index 0000000000..481638ffc8 --- /dev/null +++ b/worker/pkg/benthos/sql/processor_neosync_mysql.go @@ -0,0 +1,167 @@ +package neosync_benthos_sql + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/doug-martin/goqu/v9" + mysqlutil "github.com/nucleuscloud/neosync/internal/mysql" + neosync_benthos 
"github.com/nucleuscloud/neosync/worker/pkg/benthos" + "github.com/warpstreamlabs/bento/public/service" +) + +func neosyncToMysqlProcessorConfig() *service.ConfigSpec { + return service.NewConfigSpec(). + Field(service.NewStringListField("columns")). + Field(service.NewStringMapField("column_data_types")). + Field(service.NewAnyMapField("column_default_properties")) +} + +func RegisterNeosyncToMysqlProcessor(env *service.Environment) error { + return env.RegisterBatchProcessor( + "neosync_to_mysql", + neosyncToMysqlProcessorConfig(), + func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchProcessor, error) { + proc, err := newNeosyncToMysqlProcessor(conf, mgr) + if err != nil { + return nil, err + } + return proc, nil + }) +} + +type neosyncToMysqlProcessor struct { + logger *service.Logger + columns []string + columnDataTypes map[string]string + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties +} + +func newNeosyncToMysqlProcessor(conf *service.ParsedConfig, mgr *service.Resources) (*neosyncToMysqlProcessor, error) { + columns, err := conf.FieldStringList("columns") + if err != nil { + return nil, err + } + + columnDataTypes, err := conf.FieldStringMap("column_data_types") + if err != nil { + return nil, err + } + + columnDefaultPropertiesConfig, err := conf.FieldAnyMap("column_default_properties") + if err != nil { + return nil, err + } + + columnDefaultProperties, err := getColumnDefaultProperties(columnDefaultPropertiesConfig) + if err != nil { + return nil, err + } + + return &neosyncToMysqlProcessor{ + logger: mgr.Logger(), + columns: columns, + columnDataTypes: columnDataTypes, + columnDefaultProperties: columnDefaultProperties, + }, nil +} + +func (p *neosyncToMysqlProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { + newBatch := make(service.MessageBatch, 0, len(batch)) + for _, msg := range batch { + root, err := msg.AsStructuredMut() + if err != nil { + return nil, err + } + newRoot, err := transformNeosyncToMysql(p.logger, root, p.columns, p.columnDataTypes, p.columnDefaultProperties) + if err != nil { + return nil, err + } + newMsg := msg.Copy() + newMsg.SetStructured(newRoot) + newBatch = append(newBatch, newMsg) + } + + if len(newBatch) == 0 { + return nil, nil + } + return []service.MessageBatch{newBatch}, nil +} + +func (m *neosyncToMysqlProcessor) Close(context.Context) error { + return nil +} + +func transformNeosyncToMysql( + logger *service.Logger, + root any, + columns []string, + columnDataTypes map[string]string, + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties, +) (map[string]any, error) { + rootMap, ok := root.(map[string]any) + if !ok { + return nil, fmt.Errorf("root value must be a map[string]any") + } + + newMap := make(map[string]any) + for col, val := range rootMap { + // Skip values that aren't in the column list to handle circular references + if !isColumnInList(col, columns) { + continue + } + colDefaults := columnDefaultProperties[col] + datatype := columnDataTypes[col] + newVal, err := getMysqlValue(val, colDefaults, datatype) + if err != nil { + logger.Warn(err.Error()) + } + newMap[col] = newVal + } + + return newMap, nil +} + +func getMysqlValue(value any, colDefaults *neosync_benthos.ColumnDefaultProperties, datatype string) (any, error) { + if colDefaults != nil && colDefaults.HasDefaultTransformer { + return goqu.Default(), nil + } + + switch v := value.(type) { + case nil: + return v, nil + case []byte: + value, err := 
handleMysqlByteSlice(v, datatype) + if err != nil { + return nil, fmt.Errorf("unable to handle byte slice: %w", err) + } + return value, nil + default: + if mysqlutil.IsJsonDataType(datatype) { + bits, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("unable to marshal JSON: %w", err) + } + return bits, nil + } + return v, nil + } +} + +func handleMysqlByteSlice(v []byte, datatype string) (any, error) { + if datatype == "bit" { + bit, err := convertStringToBit(string(v)) + if err != nil { + return nil, fmt.Errorf("unable to convert bit string to SQL bit []byte: %w", err) + } + return bit, nil + } else if mysqlutil.IsJsonDataType(datatype) { + validJson, err := getValidJson(v) + if err != nil { + return nil, fmt.Errorf("unable to get valid json: %w", err) + } + return validJson, nil + } + return v, nil +} diff --git a/worker/pkg/benthos/sql/processor_neosync_mysql_test.go b/worker/pkg/benthos/sql/processor_neosync_mysql_test.go new file mode 100644 index 0000000000..1a3a9266b1 --- /dev/null +++ b/worker/pkg/benthos/sql/processor_neosync_mysql_test.go @@ -0,0 +1,263 @@ +package neosync_benthos_sql + +import ( + "context" + "encoding/json" + "testing" + + "github.com/doug-martin/goqu/v9" + neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" + "github.com/stretchr/testify/require" + "github.com/warpstreamlabs/bento/public/service" +) + +func Test_transformNeosyncToMysql(t *testing.T) { + logger := &service.Logger{} + columns := []string{"id", "name", "data", "bits", "default_col"} + columnDataTypes := map[string]string{ + "id": "int", + "name": "varchar", + "data": "json", + "bits": "bit", + } + columnDefaultProperties := map[string]*neosync_benthos.ColumnDefaultProperties{ + "default_col": {HasDefaultTransformer: true}, + } + + t.Run("handles basic values", func(t *testing.T) { + input := map[string]any{ + "id": 1, + "name": "test", + "data": map[string]string{"foo": "bar"}, + "bits": []byte("1"), + "default_col": "should be default", + } + + result, err := transformNeosyncToMysql(logger, input, columns, columnDataTypes, columnDefaultProperties) + require.NoError(t, err) + + require.Equal(t, 1, result["id"]) + require.Equal(t, "test", result["name"]) + require.Equal(t, []byte(`{"foo":"bar"}`), result["data"]) + require.Equal(t, []byte{1}, result["bits"]) + require.Equal(t, goqu.Default(), result["default_col"]) + }) + + t.Run("handles nil values", func(t *testing.T) { + input := map[string]any{ + "id": nil, + "name": nil, + } + + result, err := transformNeosyncToMysql(logger, input, columns, columnDataTypes, columnDefaultProperties) + require.NoError(t, err) + + require.Nil(t, result["id"]) + require.Nil(t, result["name"]) + }) + + t.Run("skips columns not in column list", func(t *testing.T) { + input := map[string]any{ + "id": 1, + "name": "test", + "unknown_column": "should not appear", + } + + result, err := transformNeosyncToMysql(logger, input, columns, columnDataTypes, columnDefaultProperties) + require.NoError(t, err) + + require.Equal(t, 1, result["id"]) + require.Equal(t, "test", result["name"]) + _, exists := result["unknown_column"] + require.False(t, exists) + }) + + t.Run("returns error for invalid root type", func(t *testing.T) { + result, err := transformNeosyncToMysql(logger, "invalid", columns, columnDataTypes, columnDefaultProperties) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "root value must be a map[string]any") + }) +} + +func Test_getMysqlValue(t *testing.T) { + t.Run("returns default for 
column with default transformer", func(t *testing.T) { + colDefaults := &neosync_benthos.ColumnDefaultProperties{HasDefaultTransformer: true} + result, err := getMysqlValue("test", colDefaults, "varchar") + require.NoError(t, err) + require.Equal(t, goqu.Default(), result) + }) + + t.Run("marshals json for json datatype", func(t *testing.T) { + input := map[string]string{"foo": "bar"} + result, err := getMysqlValue(input, nil, "json") + require.NoError(t, err) + require.Equal(t, []byte(`{"foo":"bar"}`), result) + }) + + t.Run("handles bit datatype", func(t *testing.T) { + result, err := getMysqlValue([]byte("1"), nil, "bit") + require.NoError(t, err) + require.Equal(t, []byte{1}, result) + }) + + t.Run("returns original value for non-special cases", func(t *testing.T) { + result, err := getMysqlValue("test", nil, "varchar") + require.NoError(t, err) + require.Equal(t, "test", result) + }) +} + +func Test_handleMysqlByteSlice(t *testing.T) { + t.Run("converts bit string to bytes", func(t *testing.T) { + result, err := handleMysqlByteSlice([]byte("1"), "bit") + require.NoError(t, err) + require.Equal(t, []byte{1}, result) + }) + + t.Run("returns original bytes for non-bit type", func(t *testing.T) { + input := []byte("test") + result, err := handleMysqlByteSlice(input, "varchar") + require.NoError(t, err) + require.Equal(t, input, result) + }) +} + +func Test_NeosyncToMysqlProcessor(t *testing.T) { + conf := ` +columns: + - id + - name + - age + - balance + - is_active + - created_at + - metadata + - default_value +column_data_types: + id: integer + name: text + age: integer + balance: double + is_active: boolean + created_at: timestamp + metadata: json + default_value: text +column_default_properties: + id: + has_default_transformer: false + name: + has_default_transformer: false + default_value: + has_default_transformer: true +` + spec := neosyncToMysqlProcessorConfig() + env := service.NewEnvironment() + + procConfig, err := spec.ParseYAML(conf, env) + require.NoError(t, err) + + proc, err := newNeosyncToMysqlProcessor(procConfig, service.MockResources()) + require.NoError(t, err) + + msgMap := map[string]any{ + "id": 1, + "name": "test", + "age": 30, + "balance": 1000.50, + "is_active": true, + "created_at": "2023-01-01T00:00:00Z", + "metadata": map[string]string{"key": "value"}, + "default_value": "some default", + } + msg := service.NewMessage(nil) + msg.SetStructured(msgMap) + batch := service.MessageBatch{ + msg, + } + + results, err := proc.ProcessBatch(context.Background(), batch) + require.NoError(t, err) + require.Len(t, results, 1) + require.Len(t, results[0], 1) + + val, err := results[0][0].AsStructured() + require.NoError(t, err) + + jsonBytes, err := json.Marshal(msgMap["metadata"]) + require.NoError(t, err) + + expected := map[string]any{ + "id": msgMap["id"], + "name": msgMap["name"], + "age": msgMap["age"], + "balance": msgMap["balance"], + "is_active": msgMap["is_active"], + "created_at": msgMap["created_at"], + "metadata": jsonBytes, + "default_value": goqu.Default(), + } + require.Equal(t, expected, val) + + require.NoError(t, proc.Close(context.Background())) +} + +func Test_NeosyncToMysqlProcessor_SubsetColumns(t *testing.T) { + conf := ` +columns: + - id + - name +column_data_types: + id: integer + name: text + age: integer + balance: double + is_active: boolean + created_at: timestamp + metadata: json +column_default_properties: + id: + has_default_transformer: false + name: + has_default_transformer: false +` + spec := neosyncToMysqlProcessorConfig() + env := 
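// The conf string above mirrors how this processor would be declared in a
// Bento pipeline; a sketch, with field names taken from
// neosyncToMysqlProcessorConfig and the surrounding pipeline shape assumed:
//
//	pipeline:
//	  processors:
//	    - neosync_to_mysql:
//	        columns: [id, name]
//	        column_data_types: {id: int, name: varchar}
//	        column_default_properties: {}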
service.NewEnvironment() + + procConfig, err := spec.ParseYAML(conf, env) + require.NoError(t, err) + + proc, err := newNeosyncToMysqlProcessor(procConfig, service.MockResources()) + require.NoError(t, err) + + msgMap := map[string]any{ + "id": 1, + "name": "test", + "age": 30, + "balance": 1000.50, + "is_active": true, + "created_at": "2023-01-01T00:00:00Z", + "metadata": map[string]string{"key": "value"}, + } + msg := service.NewMessage(nil) + msg.SetStructured(msgMap) + batch := service.MessageBatch{ + msg, + } + + results, err := proc.ProcessBatch(context.Background(), batch) + require.NoError(t, err) + require.Len(t, results, 1) + require.Len(t, results[0], 1) + + val, err := results[0][0].AsStructured() + require.NoError(t, err) + + expected := map[string]any{ + "id": msgMap["id"], + "name": msgMap["name"], + } + require.Equal(t, expected, val) + + require.NoError(t, proc.Close(context.Background())) +} diff --git a/worker/pkg/benthos/sql/processor_neosync_pgx.go b/worker/pkg/benthos/sql/processor_neosync_pgx.go index 0f1335cd45..0a3e572da4 100644 --- a/worker/pkg/benthos/sql/processor_neosync_pgx.go +++ b/worker/pkg/benthos/sql/processor_neosync_pgx.go @@ -2,16 +2,27 @@ package neosync_benthos_sql import ( "context" + "encoding/binary" + "encoding/json" "fmt" + "slices" + "strconv" + "strings" + "github.com/doug-martin/goqu/v9" "github.com/lib/pq" "github.com/nucleuscloud/neosync/internal/gotypeutil" neosynctypes "github.com/nucleuscloud/neosync/internal/neosync-types" + pgutil "github.com/nucleuscloud/neosync/internal/postgres" + neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" "github.com/warpstreamlabs/bento/public/service" ) func neosyncToPgxProcessorConfig() *service.ConfigSpec { - return service.NewConfigSpec().Field(service.NewStringMapField("column_data_types")) + return service.NewConfigSpec(). + Field(service.NewStringListField("columns")). + Field(service.NewStringMapField("column_data_types")). 
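// The spec now carries the column list and per-column default properties
// alongside the data types, which lets transformNeosyncToPgx (below) drop
// columns that are not mapped and emit DEFAULT for default-transformed ones.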
+ Field(service.NewAnyMapField("column_default_properties")) } func RegisterNeosyncToPgxProcessor(env *service.Environment) error { @@ -19,19 +30,48 @@ func RegisterNeosyncToPgxProcessor(env *service.Environment) error { "neosync_to_pgx", neosyncToPgxProcessorConfig(), func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchProcessor, error) { - proc := newNeosyncToPgxProcessor(conf, mgr) + proc, err := newNeosyncToPgxProcessor(conf, mgr) + if err != nil { + return nil, err + } return proc, nil }) } type neosyncToPgxProcessor struct { - logger *service.Logger + logger *service.Logger + columns []string + columnDataTypes map[string]string + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties } -func newNeosyncToPgxProcessor(_ *service.ParsedConfig, mgr *service.Resources) *neosyncToPgxProcessor { - return &neosyncToPgxProcessor{ - logger: mgr.Logger(), +func newNeosyncToPgxProcessor(conf *service.ParsedConfig, mgr *service.Resources) (*neosyncToPgxProcessor, error) { + columnDataTypes, err := conf.FieldStringMap("column_data_types") + if err != nil { + return nil, err } + + columns, err := conf.FieldStringList("columns") + if err != nil { + return nil, err + } + + columnDefaultPropertiesConfig, err := conf.FieldAnyMap("column_default_properties") + if err != nil { + return nil, err + } + + columnDefaultProperties, err := getColumnDefaultProperties(columnDefaultPropertiesConfig) + if err != nil { + return nil, err + } + + return &neosyncToPgxProcessor{ + logger: mgr.Logger(), + columns: columns, + columnDataTypes: columnDataTypes, + columnDefaultProperties: columnDefaultProperties, + }, nil } func (p *neosyncToPgxProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { @@ -41,7 +81,10 @@ func (p *neosyncToPgxProcessor) ProcessBatch(ctx context.Context, batch service. if err != nil { return nil, err } - newRoot := p.transform(root) + newRoot, err := transformNeosyncToPgx(p.logger, root, p.columns, p.columnDataTypes, p.columnDefaultProperties) + if err != nil { + return nil, err + } newMsg := msg.Copy() newMsg.SetStructured(newRoot) newBatch = append(newBatch, newMsg) @@ -56,32 +99,202 @@ func (p *neosyncToPgxProcessor) ProcessBatch(ctx context.Context, batch service. 
func (m *neosyncToPgxProcessor) Close(context.Context) error { return nil } +func transformNeosyncToPgx( + logger *service.Logger, + root any, + columns []string, + columnDataTypes map[string]string, + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties, +) (map[string]any, error) { + rootMap, ok := root.(map[string]any) + if !ok { + return nil, fmt.Errorf("root value must be a map[string]any") + } -func (p *neosyncToPgxProcessor) transform(root any) any { - switch v := root.(type) { - case map[string]any: - newMap := make(map[string]any) - for k, v2 := range v { - newValue := p.transform(v2) - newMap[k] = newValue + newMap := make(map[string]any) + for col, val := range rootMap { + // Skip values that aren't in the column list to handle circular references + if !isColumnInList(col, columns) { + continue + } + colDefaults := columnDefaultProperties[col] + datatype := columnDataTypes[col] + newVal, err := getPgxValue(val, colDefaults, datatype) + if err != nil { + logger.Warn(err.Error()) } - return newMap + newMap[col] = newVal + } + + return newMap, nil +} + +func getPgxValue(value any, colDefaults *neosync_benthos.ColumnDefaultProperties, datatype string) (any, error) { + value, isNeosyncValue, err := getPgxNeosyncValue(value) + if err != nil { + return nil, err + } + if isNeosyncValue { + return value, nil + } + + if colDefaults != nil && colDefaults.HasDefaultTransformer { + return goqu.Default(), nil + } + + switch v := value.(type) { case nil: - return v + return v, nil + case []byte: + value, err := handlePgxByteSlice(v, datatype) + if err != nil { + return nil, fmt.Errorf("unable to handle byte slice: %w", err) + } + return value, nil default: - // Check if the type implements Value() method - if valuer, ok := v.(neosynctypes.NeosyncPgxValuer); ok { - value, err := valuer.ValuePgx() + if pgutil.IsJsonPgDataType(datatype) { + bits, err := json.Marshal(value) if err != nil { - p.logger.Warn(fmt.Sprintf("unable to get PGX value: %v", err)) - return v - } - if gotypeutil.IsSlice(value) { - return pq.Array(value) + return nil, fmt.Errorf("unable to marshal JSON: %w", err) } - return value + return bits, nil + } else if gotypeutil.IsMultiDimensionalSlice(v) || gotypeutil.IsSliceOfMaps(v) { + return goqu.Literal(pgutil.FormatPgArrayLiteral(v, datatype)), nil + } else if gotypeutil.IsSlice(v) { + return pq.Array(v), nil + } + return v, nil + } +} + +func getPgxNeosyncValue(root any) (value any, isNeosyncValue bool, err error) { + if valuer, ok := root.(neosynctypes.NeosyncPgxValuer); ok { + value, err := valuer.ValuePgx() + if err != nil { + return nil, false, fmt.Errorf("unable to get PGX value from NeosyncPgxValuer: %w", err) + } + if gotypeutil.IsSlice(value) { + return pq.Array(value), true, nil + } + return value, true, nil + } + return root, false, nil +} + +func handlePgxByteSlice(v []byte, datatype string) (any, error) { + if pgutil.IsPgArrayColumnDataType(datatype) { + // this handles the case where the array is in the form {1,2,3} + if strings.HasPrefix(string(v), "{") { + return string(v), nil + } + pgarray, err := processPgArrayFromJson(v, datatype) + if err != nil { + return nil, fmt.Errorf("unable to process PG Array: %w", err) + } + return pgarray, nil + } + switch datatype { + case "bit": + bit, err := convertStringToBit(string(v)) + if err != nil { + return nil, fmt.Errorf("unable to convert bit string to SQL bit []byte: %w", err) + } + return bit, nil + case "json", "jsonb": + validJson, err := getValidJson(v) + if err != nil { + return nil, 
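// getValidJson (below) passes well-formed JSON through untouched and quotes
// anything else so it survives a json/jsonb insert:
//
//	b, _ := getValidJson([]byte(`{"a":1}`))    // unchanged
//	b, _ = getValidJson([]byte(`hello world`)) // []byte(`"hello world"`)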
fmt.Errorf("unable to get valid json: %w", err) + } + return validJson, nil + case "money", "uuid", "time with time zone", "timestamp with time zone": + // Convert UUID []byte to string before inserting since postgres driver stores uuid bytes in different order + return string(v), nil + } + return v, nil +} + +// this expects the bits to be in the form [1,2,3] +func processPgArrayFromJson(bits []byte, datatype string) (any, error) { + var pgarray []any + err := json.Unmarshal(bits, &pgarray) + if err != nil { + return nil, err + } + switch datatype { + case "json[]", "jsonb[]": + jsonArray, err := stringifyJsonArray(pgarray) + if err != nil { + return nil, err + } + return pq.Array(jsonArray), nil + default: + return pq.Array(pgarray), nil + } +} + +// handles case where json strings are not quoted +func getValidJson(jsonData []byte) ([]byte, error) { + isValidJson := json.Valid(jsonData) + if isValidJson { + return jsonData, nil + } + + quotedData, err := json.Marshal(string(jsonData)) + if err != nil { + return nil, err + } + return quotedData, nil +} + +func stringifyJsonArray(pgarray []any) ([]string, error) { + jsonArray := make([]string, len(pgarray)) + for i, item := range pgarray { + bytes, err := json.Marshal(item) + if err != nil { + return nil, err + } + jsonArray[i] = string(bytes) + } + return jsonArray, nil +} + +func convertStringToBit(bitString string) ([]byte, error) { + val, err := strconv.ParseUint(bitString, 2, len(bitString)) + if err != nil { + return nil, err + } + + // Always allocate 8 bytes for PutUint64 + bytes := make([]byte, 8) + binary.BigEndian.PutUint64(bytes, val) + + // Calculate actual needed bytes and return only those + neededBytes := (len(bitString) + 7) / 8 + return bytes[len(bytes)-neededBytes:], nil +} + +func isColumnInList(column string, columns []string) bool { + return slices.Contains(columns, column) +} + +func getColumnDefaultProperties(columnDefaultPropertiesConfig map[string]*service.ParsedConfig) (map[string]*neosync_benthos.ColumnDefaultProperties, error) { + columnDefaultProperties := map[string]*neosync_benthos.ColumnDefaultProperties{} + for key, properties := range columnDefaultPropertiesConfig { + props, err := properties.FieldAny() + if err != nil { + return nil, err + } + jsonData, err := json.Marshal(props) + if err != nil { + return nil, fmt.Errorf("failed to marshal properties for key %s: %w", key, err) + } + + var colDefaults neosync_benthos.ColumnDefaultProperties + if err := json.Unmarshal(jsonData, &colDefaults); err != nil { + return nil, fmt.Errorf("failed to unmarshal properties for key %s: %w", key, err) } - return v + columnDefaultProperties[key] = &colDefaults } + return columnDefaultProperties, nil } diff --git a/worker/pkg/benthos/sql/processor_neosync_pgx_test.go b/worker/pkg/benthos/sql/processor_neosync_pgx_test.go new file mode 100644 index 0000000000..a29f58e0c3 --- /dev/null +++ b/worker/pkg/benthos/sql/processor_neosync_pgx_test.go @@ -0,0 +1,558 @@ +package neosync_benthos_sql + +import ( + "context" + "database/sql/driver" + "encoding/json" + "testing" + + "github.com/doug-martin/goqu/v9" + "github.com/lib/pq" + neosynctypes "github.com/nucleuscloud/neosync/internal/neosync-types" + pgutil "github.com/nucleuscloud/neosync/internal/postgres" + neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" + "github.com/stretchr/testify/require" + "github.com/warpstreamlabs/bento/public/service" +) + +func Test_convertStringToBit(t *testing.T) { + t.Run("8 bits", func(t *testing.T) { + got, err := 
convertStringToBit("10101010") + require.NoError(t, err) + expected := []byte{170} + require.Equalf(t, expected, got, "got %v, want %v", got, expected) + }) + + t.Run("1 bit", func(t *testing.T) { + got, err := convertStringToBit("1") + require.NoError(t, err) + expected := []byte{1} + require.Equalf(t, expected, got, "got %v, want %v", got, expected) + }) + + t.Run("16 bits", func(t *testing.T) { + got, err := convertStringToBit("1010101010101010") + require.NoError(t, err) + expected := []byte{170, 170} + require.Equalf(t, expected, got, "got %v, want %v", got, expected) + }) + + t.Run("24 bits", func(t *testing.T) { + got, err := convertStringToBit("101010101111111100000000") + require.NoError(t, err) + expected := []byte{170, 255, 0} + require.Equalf(t, expected, got, "got %v, want %v", got, expected) + }) + + t.Run("invalid binary string", func(t *testing.T) { + _, err := convertStringToBit("102") + require.Error(t, err) + }) + + t.Run("empty string", func(t *testing.T) { + _, err := convertStringToBit("") + require.Error(t, err) + }) +} + +func Test_getValidJson(t *testing.T) { + t.Run("already valid json", func(t *testing.T) { + input := []byte(`{"key": "value"}`) + got, err := getValidJson(input) + require.NoError(t, err) + require.Equal(t, input, got) + }) + + t.Run("unquoted string", func(t *testing.T) { + input := []byte(`hello world`) + got, err := getValidJson(input) + require.NoError(t, err) + expected := []byte(`"hello world"`) + require.Equal(t, expected, got) + }) +} + +func Test_stringifyJsonArray(t *testing.T) { + t.Run("array of objects", func(t *testing.T) { + input := []any{ + map[string]any{"name": "Alice"}, + map[string]any{"name": "Bob"}, + } + got, err := stringifyJsonArray(input) + require.NoError(t, err) + expected := []string{`{"name":"Alice"}`, `{"name":"Bob"}`} + require.Equal(t, expected, got) + }) + + t.Run("empty array", func(t *testing.T) { + got, err := stringifyJsonArray([]any{}) + require.NoError(t, err) + require.Equal(t, []string{}, got) + }) +} + +func Test_isColumnInList(t *testing.T) { + columns := []string{"id", "name", "email"} + + t.Run("column exists", func(t *testing.T) { + require.True(t, isColumnInList("name", columns)) + }) + + t.Run("column does not exist", func(t *testing.T) { + require.False(t, isColumnInList("age", columns)) + }) + + t.Run("empty column list", func(t *testing.T) { + require.False(t, isColumnInList("name", []string{})) + }) +} + +func Test_processPgArrayFromJson(t *testing.T) { + t.Run("json array", func(t *testing.T) { + input := []byte(`[{"tag":"cool"},{"tag":"awesome"}]`) + got, err := processPgArrayFromJson(input, "json[]") + require.NoError(t, err) + + // Convert back to string for comparison since pq.Array isn't easily comparable + arr, ok := got.(interface{ Value() (driver.Value, error) }) + require.True(t, ok) + val, err := arr.Value() + require.NoError(t, err) + strArr, ok := val.(string) + require.True(t, ok) + require.Equal(t, `{"{\"tag\":\"cool\"}","{\"tag\":\"awesome\"}"}`, strArr) + }) + + t.Run("invalid json", func(t *testing.T) { + input := []byte(`[invalid json]`) + _, err := processPgArrayFromJson(input, "json[]") + require.Error(t, err) + }) +} + +func Test_transformNeosyncToPgx(t *testing.T) { + logger := &service.Logger{} + columns := []string{"id", "name", "data"} + columnDataTypes := map[string]string{ + "id": "integer", + "name": "text", + "data": "json", + } + columnDefaultProperties := map[string]*neosync_benthos.ColumnDefaultProperties{ + "id": {HasDefaultTransformer: true}, + } + + 
t.Run("transforms values correctly", func(t *testing.T) { + input := map[string]any{ + "id": 123, + "name": "test", + "data": map[string]string{"foo": "bar"}, + } + + got, err := transformNeosyncToPgx(logger, input, columns, columnDataTypes, columnDefaultProperties) + require.NoError(t, err) + + // id should be DEFAULT due to HasDefaultTransformer + idVal, ok := got["id"].(goqu.Expression) + require.True(t, ok) + require.NotNil(t, idVal) + + require.Equal(t, "test", got["name"]) + + // data should be JSON encoded + dataBytes, ok := got["data"].([]byte) + require.True(t, ok) + require.JSONEq(t, `{"foo":"bar"}`, string(dataBytes)) + }) + + t.Run("skips columns not in list", func(t *testing.T) { + input := map[string]any{ + "id": 123, + "name": "test", + "ignored": "value", + } + + got, err := transformNeosyncToPgx(logger, input, columns, columnDataTypes, columnDefaultProperties) + require.NoError(t, err) + require.NotContains(t, got, "ignored") + }) + + t.Run("handles nil values", func(t *testing.T) { + input := map[string]any{ + "id": nil, + "name": nil, + } + + got, err := transformNeosyncToPgx(logger, input, columns, columnDataTypes, columnDefaultProperties) + require.NoError(t, err) + require.Nil(t, got["name"]) + }) + + t.Run("invalid input type", func(t *testing.T) { + _, err := transformNeosyncToPgx(logger, "not a map", columns, columnDataTypes, columnDefaultProperties) + require.Error(t, err) + }) +} +func TestHandlePgxByteSlice(t *testing.T) { + t.Run("handles array types", func(t *testing.T) { + input := []byte(`[1,2,3]`) + got, err := handlePgxByteSlice(input, "integer[]") + require.NoError(t, err) + + // Should be wrapped in pq.Array + arr, ok := got.(interface{ Value() (driver.Value, error) }) + require.True(t, ok) + + val, err := arr.Value() + require.NoError(t, err) + require.Equal(t, "{1,2,3}", val) + }) + + t.Run("handles bit type", func(t *testing.T) { + input := []byte("1010") + got, err := handlePgxByteSlice(input, "bit") + require.NoError(t, err) + + bytes, ok := got.([]byte) + require.True(t, ok) + require.Equal(t, []byte{10}, bytes) // 1010 binary = 10 decimal + }) + + t.Run("handles json type", func(t *testing.T) { + input := []byte(`{"foo":"bar"}`) + got, err := handlePgxByteSlice(input, "json") + require.NoError(t, err) + + bytes, ok := got.([]byte) + require.True(t, ok) + require.JSONEq(t, `{"foo":"bar"}`, string(bytes)) + }) + + t.Run("handles jsonb type", func(t *testing.T) { + input := []byte(`{"foo":"bar"}`) + got, err := handlePgxByteSlice(input, "jsonb") + require.NoError(t, err) + + bytes, ok := got.([]byte) + require.True(t, ok) + require.JSONEq(t, `{"foo":"bar"}`, string(bytes)) + }) + + t.Run("handles uuid type", func(t *testing.T) { + input := []byte("550e8400-e29b-41d4-a716-446655440000") + got, err := handlePgxByteSlice(input, "uuid") + require.NoError(t, err) + + str, ok := got.(string) + require.True(t, ok) + require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", str) + }) + + t.Run("handles timestamp with time zone", func(t *testing.T) { + input := []byte("2023-01-01 12:00:00+00") + got, err := handlePgxByteSlice(input, "timestamp with time zone") + require.NoError(t, err) + + str, ok := got.(string) + require.True(t, ok) + require.Equal(t, "2023-01-01 12:00:00+00", str) + }) + + t.Run("handles time with time zone", func(t *testing.T) { + input := []byte("12:00:00+00") + got, err := handlePgxByteSlice(input, "time with time zone") + require.NoError(t, err) + + str, ok := got.(string) + require.True(t, ok) + require.Equal(t, "12:00:00+00", str) + 
}) + + t.Run("handles money type", func(t *testing.T) { + input := []byte("$123.45") + got, err := handlePgxByteSlice(input, "money") + require.NoError(t, err) + + str, ok := got.(string) + require.True(t, ok) + require.Equal(t, "$123.45", str) + }) + + t.Run("returns original bytes for unknown type", func(t *testing.T) { + input := []byte("test") + got, err := handlePgxByteSlice(input, "text") + require.NoError(t, err) + require.Equal(t, input, got) + }) + + t.Run("handles invalid bit string", func(t *testing.T) { + input := []byte("not a bit string") + _, err := handlePgxByteSlice(input, "bit") + require.Error(t, err) + }) + + t.Run("handles invalid json", func(t *testing.T) { + input := []byte("{invalid json}") + got, err := handlePgxByteSlice(input, "json") + require.NoError(t, err) + require.JSONEq(t, `"{invalid json}"`, string(got.([]byte))) + }) +} + +func Test_getPgxValue(t *testing.T) { + t.Run("handles json values", func(t *testing.T) { + testCases := []struct { + name string + input any + datatype string + expected []byte + }{ + { + name: "string value", + input: "value1", + datatype: "json", + expected: []byte(`"value1"`), + }, + { + name: "number value", + input: 42, + datatype: "jsonb", + expected: []byte(`42`), + }, + { + name: "boolean value", + input: true, + datatype: "json", + expected: []byte(`true`), + }, + { + name: "object value", + input: map[string]any{"key": "value"}, + datatype: "jsonb", + expected: []byte(`{"key":"value"}`), + }, + { + name: "array value", + input: []int{1, 2, 3}, + datatype: "json", + expected: []byte(`[1,2,3]`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := getPgxValue(tc.input, nil, tc.datatype) + require.NoError(t, err) + require.Equal(t, tc.expected, got) + }) + } + }) + + t.Run("handles default transformer", func(t *testing.T) { + colDefaults := &neosync_benthos.ColumnDefaultProperties{ + HasDefaultTransformer: true, + } + got, err := getPgxValue("test", colDefaults, "text") + require.NoError(t, err) + require.Equal(t, goqu.Default(), got) + }) + + t.Run("handles nil value", func(t *testing.T) { + got, err := getPgxValue(nil, nil, "text") + require.NoError(t, err) + require.Nil(t, got) + }) + + t.Run("handles byte slice", func(t *testing.T) { + input := []byte("test") + got, err := getPgxValue(input, nil, "text") + require.NoError(t, err) + require.Equal(t, input, got) + }) + + t.Run("handles slice", func(t *testing.T) { + input := []string{"a", "b", "c"} + got, err := getPgxValue(input, nil, "text[]") + require.NoError(t, err) + require.Equal(t, pq.Array(input), got) + }) + + t.Run("handles multidimensional slice", func(t *testing.T) { + input := [][]string{{"a", "b"}, {"c", "d"}} + got, err := getPgxValue(input, nil, "text[][]") + require.NoError(t, err) + require.Equal(t, goqu.Literal(pgutil.FormatPgArrayLiteral(input, "text[][]")), got) + }) + + t.Run("handles slice of maps", func(t *testing.T) { + input := []map[string]string{{"key": "value"}} + got, err := getPgxValue(input, nil, "jsonb[]") + require.NoError(t, err) + require.Equal(t, goqu.Literal(pgutil.FormatPgArrayLiteral(input, "jsonb[]")), got) + }) +} + +func Test_NeosyncToPgxProcessor(t *testing.T) { + conf := ` +columns: + - id + - name + - age + - balance + - is_active + - created_at + - tags + - metadata + - interval + - default_value +column_data_types: + id: integer + name: text + age: integer + balance: double + is_active: boolean + created_at: timestamp + tags: text[] + metadata: jsonb + interval: interval + 
default_value: text +column_default_properties: + id: + has_default_transformer: false + name: + has_default_transformer: false + default_value: + has_default_transformer: true +` + spec := neosyncToPgxProcessorConfig() + env := service.NewEnvironment() + + procConfig, err := spec.ParseYAML(conf, env) + require.NoError(t, err) + + proc, err := newNeosyncToPgxProcessor(procConfig, service.MockResources()) + require.NoError(t, err) + + interval, err := neosynctypes.NewInterval() + require.NoError(t, err) + interval.ScanPgx(map[string]any{ + "months": 1, + "days": 10, + "microseconds": 3600000000, + }) + + msgMap := map[string]any{ + "id": 1, + "name": "test", + "age": 30, + "balance": 1000.50, + "is_active": true, + "created_at": "2023-01-01T00:00:00Z", + "tags": []string{"tag1", "tag2"}, + "metadata": map[string]string{"key": "value"}, + "interval": interval, + "default_value": "some value", + } + msg := service.NewMessage(nil) + msg.SetStructured(msgMap) + batch := service.MessageBatch{ + msg, + } + + results, err := proc.ProcessBatch(context.Background(), batch) + require.NoError(t, err) + require.Len(t, results, 1) + require.Len(t, results[0], 1) + + val, err := results[0][0].AsStructured() + require.NoError(t, err) + + intervalVal, err := interval.ValuePgx() + require.NoError(t, err) + + jsonBytes, err := json.Marshal(msgMap["metadata"]) + require.NoError(t, err) + + expected := map[string]any{ + "id": msgMap["id"], + "name": msgMap["name"], + "age": msgMap["age"], + "balance": msgMap["balance"], + "is_active": msgMap["is_active"], + "created_at": msgMap["created_at"], + "tags": pq.Array(msgMap["tags"]), + "metadata": jsonBytes, + "interval": intervalVal, + "default_value": goqu.Default(), + } + require.Equal(t, expected, val) + + require.NoError(t, proc.Close(context.Background())) +} + +func Test_NeosyncToPgxProcessor_SubsetColumns(t *testing.T) { + conf := ` +columns: + - id + - name +column_data_types: + id: integer + name: text + age: integer + balance: double + is_active: boolean + created_at: timestamp + tags: text[] + metadata: jsonb + interval: interval +column_default_properties: + id: + has_default_transformer: false + name: + has_default_transformer: false +` + spec := neosyncToPgxProcessorConfig() + env := service.NewEnvironment() + + procConfig, err := spec.ParseYAML(conf, env) + require.NoError(t, err) + + proc, err := newNeosyncToPgxProcessor(procConfig, service.MockResources()) + require.NoError(t, err) + + msgMap := map[string]any{ + "id": 1, + "name": "test", + "age": 30, + "balance": 1000.50, + "is_active": true, + "created_at": "2023-01-01T00:00:00Z", + "tags": []string{"tag1", "tag2"}, + "metadata": map[string]string{"key": "value"}, + "interval": neosynctypes.Interval{ + Months: 1, + Days: 10, + Microseconds: 3600000000, + }, + } + msg := service.NewMessage(nil) + msg.SetStructured(msgMap) + batch := service.MessageBatch{ + msg, + } + + results, err := proc.ProcessBatch(context.Background(), batch) + require.NoError(t, err) + require.Len(t, results, 1) + require.Len(t, results[0], 1) + + val, err := results[0][0].AsStructured() + require.NoError(t, err) + + expected := map[string]any{ + "id": 1, + "name": "test", + } + require.Equal(t, expected, val) + + require.NoError(t, proc.Close(context.Background())) +} diff --git a/worker/pkg/query-builder/insert-query-builder.go b/worker/pkg/query-builder/insert-query-builder.go index 854a455b15..1b592214bb 100644 --- a/worker/pkg/query-builder/insert-query-builder.go +++ b/worker/pkg/query-builder/insert-query-builder.go @@ 
-1,33 +1,22 @@ package querybuilder import ( - "encoding/json" "fmt" "log/slog" "strings" "github.com/doug-martin/goqu/v9" - "github.com/doug-martin/goqu/v9/exp" - "github.com/lib/pq" sqlmanager_postgres "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/postgres" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" - gotypeutil "github.com/nucleuscloud/neosync/internal/gotypeutil" - mysqlutil "github.com/nucleuscloud/neosync/internal/mysql" - pgutil "github.com/nucleuscloud/neosync/internal/postgres" sqlserverutil "github.com/nucleuscloud/neosync/internal/sqlserver" - neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" ) func GetInsertBuilder( logger *slog.Logger, driver, schema, table string, - columns []string, opts ...InsertOption, ) (InsertQueryBuilder, error) { - options := &InsertOptions{ - columnDefaults: []*neosync_benthos.ColumnDefaultProperties{}, - columnDataTypes: []string{}, - } + options := &InsertOptions{} for _, opt := range opts { opt(options) } @@ -39,7 +28,6 @@ func GetInsertBuilder( logger: logger, schema: schema, table: table, - columns: columns, options: options, }, nil case sqlmanager_shared.MysqlDriver: @@ -48,7 +36,6 @@ func GetInsertBuilder( logger: logger, schema: schema, table: table, - columns: columns, options: options, }, nil case sqlmanager_shared.MssqlDriver: @@ -57,7 +44,6 @@ func GetInsertBuilder( logger: logger, schema: schema, table: table, - columns: columns, options: options, }, nil default: @@ -68,27 +54,19 @@ func GetInsertBuilder( // InsertQueryBuilder provides an interface for building SQL insert queries across different database drivers. type InsertQueryBuilder interface { // BuildInsertQuery generates a complete SQL insert statement for multiple rows of data. - BuildInsertQuery(rows [][]any) (query string, args []any, err error) - - // BuildPreparedInsertQuerySingleRow generates a prepared SQL insert statement for a single row. 
- BuildPreparedInsertQuerySingleRow() (query string, err error) - // BuildPreparedInsertArgs processes the input rows and returns properly formatted arguments for use with a prepared statement - BuildPreparedInsertArgs(rows [][]any) [][]any + BuildInsertQuery(rows []map[string]any) (query string, args []any, err error) } type InsertOption func(*InsertOptions) type InsertOptions struct { - rawInsertMode bool - onConflictDoNothing bool - columnDataTypes []string - columnDefaults []*neosync_benthos.ColumnDefaultProperties - prefix, suffix *string + shouldOverrideColumnDefault bool + onConflictDoNothing bool + prefix, suffix *string } -// WithRawInsertMode inserts data as is -func WithRawInsertMode() InsertOption { +func WithShouldOverrideColumnDefault() InsertOption { return func(opts *InsertOptions) { - opts.rawInsertMode = true + opts.shouldOverrideColumnDefault = true } } @@ -113,128 +91,37 @@ func WithOnConflictDoNothing() InsertOption { } } -// WithColumnDataTypes adds column datatypes -func WithColumnDataTypes(types []string) InsertOption { - return func(opts *InsertOptions) { - opts.columnDataTypes = types - } -} - -// WithColumnDefaults adds ColumnDefaultProperties -func WithColumnDefaults(defaults []*neosync_benthos.ColumnDefaultProperties) InsertOption { - return func(opts *InsertOptions) { - opts.columnDefaults = defaults - } -} - type PostgresDriver struct { driver string logger *slog.Logger schema, table string - columns []string options *InsertOptions } -func (d *PostgresDriver) BuildInsertQuery(rows [][]any) (query string, queryargs []any, err error) { - var goquRows []exp.Vals - if d.options.rawInsertMode { - goquRows = toGoquVals(updateDefaultVals(rows, d.options.columnDefaults)) - } else { - goquRows = toGoquVals(getPostgresVals(d.logger, rows, d.options.columnDataTypes, d.options.columnDefaults)) - } +func (d *PostgresDriver) BuildInsertQuery(rows []map[string]any) (query string, queryargs []any, err error) { + goquRows := toGoquRecords(rows) - insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, d.columns, goquRows, &d.options.onConflictDoNothing) + insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, goquRows, &d.options.onConflictDoNothing) if err != nil { return "", nil, err } - if d.shouldOverrideColumnDefault(d.options.columnDefaults) { + if d.options.shouldOverrideColumnDefault { insertQuery = sqlmanager_postgres.BuildPgInsertIdentityAlwaysSql(insertQuery) } return insertQuery, args, err } -func (d *PostgresDriver) BuildPreparedInsertQuerySingleRow() (string, error) { - query, err := BuildPreparedInsertQuery(d.driver, d.schema, d.table, d.columns, 1, d.options.onConflictDoNothing) - if err != nil { - return "", err - } - - if d.shouldOverrideColumnDefault(d.options.columnDefaults) { - query = sqlmanager_postgres.BuildPgInsertIdentityAlwaysSql(query) - } - return query, err -} - -func (d *PostgresDriver) BuildPreparedInsertArgs(rows [][]any) [][]any { - if d.options.rawInsertMode { - return rows - } - return getPostgresVals(d.logger, rows, d.options.columnDataTypes, d.options.columnDefaults) -} - -// TODO move this logic to PGX processor -func getPostgresVals(logger *slog.Logger, rows [][]any, columnDataTypes []string, columnDefaultProperties []*neosync_benthos.ColumnDefaultProperties) [][]any { - newVals := [][]any{} - for _, row := range rows { - newRow := []any{} - for i, a := range row { - var colDataType string - if i < len(columnDataTypes) { - colDataType = columnDataTypes[i] - } - var colDefaults 
*neosync_benthos.ColumnDefaultProperties - if i < len(columnDefaultProperties) { - colDefaults = columnDefaultProperties[i] - } - if pgutil.IsJsonPgDataType(colDataType) { - bits, err := json.Marshal(a) - if err != nil { - logger.Error("unable to marshal JSON", "error", err.Error()) - newRow = append(newRow, a) - continue - } - newRow = append(newRow, bits) - } else if gotypeutil.IsMultiDimensionalSlice(a) || gotypeutil.IsSliceOfMaps(a) { - newRow = append(newRow, goqu.Literal(pgutil.FormatPgArrayLiteral(a, colDataType))) - } else if gotypeutil.IsSlice(a) { - newRow = append(newRow, pq.Array(a)) - } else if colDefaults != nil && colDefaults.HasDefaultTransformer { - newRow = append(newRow, goqu.Literal(defaultStr)) - } else { - newRow = append(newRow, a) - } - } - newVals = append(newVals, newRow) - } - return newVals -} - -func (d *PostgresDriver) shouldOverrideColumnDefault(columnDefaults []*neosync_benthos.ColumnDefaultProperties) bool { - for _, cd := range columnDefaults { - if cd != nil && !cd.HasDefaultTransformer && cd.NeedsOverride { - return true - } - } - return false -} - type MysqlDriver struct { driver string logger *slog.Logger schema, table string - columns []string options *InsertOptions } -func (d *MysqlDriver) BuildInsertQuery(rows [][]any) (query string, queryargs []any, err error) { - var goquRows []exp.Vals - if d.options.rawInsertMode { - goquRows = toGoquVals(updateDefaultVals(rows, d.options.columnDefaults)) - } else { - goquRows = toGoquVals(getMysqlVals(d.logger, rows, d.options.columnDataTypes, d.options.columnDefaults)) - } +func (d *MysqlDriver) BuildInsertQuery(rows []map[string]any) (query string, queryargs []any, err error) { + goquRows := toGoquRecords(rows) - insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, d.columns, goquRows, &d.options.onConflictDoNothing) + insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, goquRows, &d.options.onConflictDoNothing) if err != nil { return "", nil, err } @@ -248,81 +135,21 @@ func (d *MysqlDriver) BuildInsertQuery(rows [][]any) (query string, queryargs [] return insertQuery, args, err } -func (d *MysqlDriver) BuildPreparedInsertQuerySingleRow() (string, error) { - query, err := BuildPreparedInsertQuery(d.driver, d.schema, d.table, d.columns, 1, d.options.onConflictDoNothing) - if err != nil { - return "", err - } - - if d.options.prefix != nil && *d.options.prefix != "" { - query = addPrefix(query, *d.options.prefix) - } - if d.options.suffix != nil && *d.options.suffix != "" { - query = addSuffix(query, *d.options.suffix) - } - return query, err -} - -func (d *MysqlDriver) BuildPreparedInsertArgs(rows [][]any) [][]any { - if d.options.rawInsertMode { - return rows - } - return getMysqlVals(d.logger, rows, d.options.columnDataTypes, d.options.columnDefaults) -} - -func getMysqlVals(logger *slog.Logger, rows [][]any, columnDataTypes []string, columnDefaultProperties []*neosync_benthos.ColumnDefaultProperties) [][]any { - newVals := [][]any{} - for _, row := range rows { - newRow := []any{} - for idx, a := range row { - var colDataType string - if idx < len(columnDataTypes) { - colDataType = columnDataTypes[idx] - } - var colDefaults *neosync_benthos.ColumnDefaultProperties - if idx < len(columnDefaultProperties) { - colDefaults = columnDefaultProperties[idx] - } - if colDefaults != nil && colDefaults.HasDefaultTransformer { - newRow = append(newRow, goqu.Literal(defaultStr)) - } else if mysqlutil.IsJsonDataType(colDataType) { - bits, err := json.Marshal(a) - if err != nil { - 
logger.Error("unable to marshal JSON", "error", err.Error()) - newRow = append(newRow, a) - continue - } - newRow = append(newRow, bits) - } else { - newRow = append(newRow, a) - } - } - newVals = append(newVals, newRow) - } - return newVals -} - type MssqlDriver struct { driver string logger *slog.Logger schema, table string - columns []string options *InsertOptions } -func (d *MssqlDriver) BuildInsertQuery(rows [][]any) (query string, queryargs []any, err error) { - processedCols, processedRow, processedColDefaults := d.filterOutDefaultIdentityColumns(d.columns, rows, d.options.columnDefaults) - if len(processedRow) == 0 { +func (d *MssqlDriver) BuildInsertQuery(rows []map[string]any) (query string, queryargs []any, err error) { + if len(rows) == 0 || areAllRowsEmpty(rows) { return sqlserverutil.GeSqlServerDefaultValuesInsertSql(d.schema, d.table, len(rows)), []any{}, nil } - var goquRows []exp.Vals - if d.options.rawInsertMode { - goquRows = toGoquVals(processedRow) - } else { - goquRows = toGoquVals(getMssqlVals(d.logger, processedRow, processedColDefaults)) - } - insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, processedCols, goquRows, &d.options.onConflictDoNothing) + goquRows := toGoquRecords(rows) + + insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, goquRows, &d.options.onConflictDoNothing) if err != nil { return "", nil, err } @@ -337,65 +164,13 @@ func (d *MssqlDriver) BuildInsertQuery(rows [][]any) (query string, queryargs [] return insertQuery, args, err } -func (d *MssqlDriver) BuildPreparedInsertQuerySingleRow() (string, error) { - query, err := BuildPreparedInsertQuery(d.driver, d.schema, d.table, d.columns, 1, d.options.onConflictDoNothing) - if err != nil { - return "", err - } - - if d.options.prefix != nil && *d.options.prefix != "" { - query = addPrefix(query, *d.options.prefix) - } - if d.options.suffix != nil && *d.options.suffix != "" { - query = addSuffix(query, *d.options.suffix) - } - return query, err -} - -func (d *MssqlDriver) BuildPreparedInsertArgs(rows [][]any) [][]any { - _, processedRow, processedColDefaults := d.filterOutDefaultIdentityColumns(d.columns, rows, d.options.columnDefaults) - if d.options.rawInsertMode { - return processedRow - } - return getMssqlVals(d.logger, processedRow, processedColDefaults) -} - -func getMssqlVals(logger *slog.Logger, rows [][]any, columnDefaultProperties []*neosync_benthos.ColumnDefaultProperties) [][]any { - newVals := [][]any{} +func areAllRowsEmpty(rows []map[string]any) bool { for _, row := range rows { - newRow := []any{} - for idx, a := range row { - var colDefaults *neosync_benthos.ColumnDefaultProperties - if idx < len(columnDefaultProperties) { - colDefaults = columnDefaultProperties[idx] - } - if colDefaults != nil && colDefaults.HasDefaultTransformer { - newRow = append(newRow, goqu.Literal(defaultStr)) - } else if gotypeutil.IsMap(a) { - bits, err := gotypeutil.MapToJson(a) - if err != nil { - logger.Error("unable to marshal map to JSON", "error", err.Error()) - newRow = append(newRow, a) - } else { - newRow = append(newRow, bits) - } - } else { - newRow = append(newRow, a) - } + if len(row) > 0 { + return false } - - newVals = append(newVals, newRow) } - return newVals -} - -func (d *MssqlDriver) filterOutDefaultIdentityColumns( - columnsNames []string, - dataRows [][]any, - colDefaultProperties []*neosync_benthos.ColumnDefaultProperties, -) (columns []string, rows [][]any, columnDefaultProperties []*neosync_benthos.ColumnDefaultProperties) { - newDataRows := 
sqlserverutil.GoTypeToSqlServerType(dataRows) - return sqlserverutil.FilterOutSqlServerDefaultIdentityColumns(d.driver, columnsNames, newDataRows, colDefaultProperties) + return true } func addPrefix(insertQuery, prefix string) string { @@ -406,34 +181,10 @@ func addSuffix(insertQuery, suffix string) string { return strings.TrimSuffix(insertQuery, ";") + ";" + suffix } -func toGoquVals(rows [][]any) []goqu.Vals { - gvals := []goqu.Vals{} - for _, row := range rows { - gval := goqu.Vals{} - for _, v := range row { - gval = append(gval, v) - } - gvals = append(gvals, gval) - } - return gvals -} - -func updateDefaultVals(rows [][]any, columnDefaultProperties []*neosync_benthos.ColumnDefaultProperties) [][]any { - newVals := [][]any{} +func toGoquRecords(rows []map[string]any) []goqu.Record { + records := []goqu.Record{} for _, row := range rows { - newRow := []any{} - for i, a := range row { - var colDefaults *neosync_benthos.ColumnDefaultProperties - if i < len(columnDefaultProperties) { - colDefaults = columnDefaultProperties[i] - } - if colDefaults != nil && colDefaults.HasDefaultTransformer { - newRow = append(newRow, goqu.Literal(defaultStr)) - } else { - newRow = append(newRow, a) - } - } - newVals = append(newVals, newRow) + records = append(records, goqu.Record(row)) } - return newVals + return records } diff --git a/worker/pkg/query-builder/query-builder.go b/worker/pkg/query-builder/query-builder.go index 307f774580..f07da39b9d 100644 --- a/worker/pkg/query-builder/query-builder.go +++ b/worker/pkg/query-builder/query-builder.go @@ -13,8 +13,6 @@ import ( "github.com/doug-martin/goqu/v9/exp" ) -const defaultStr = "DEFAULT" - type SubsetReferenceKey struct { Table string Columns []string @@ -77,20 +75,12 @@ func BuildSelectLimitQuery( func BuildInsertQuery( driver, schema, table string, - columns []string, - values []goqu.Vals, + records []goqu.Record, onConflictDoNothing *bool, ) (sql string, args []any, err error) { builder := getGoquDialect(driver) sqltable := goqu.S(schema).Table(table) - insertCols := make([]any, len(columns)) - for i, col := range columns { - insertCols[i] = col - } - insert := builder.Insert(sqltable).Prepared(true).Cols(insertCols...) - for _, row := range values { - insert = insert.Vals(row) - } + insert := builder.Insert(sqltable).Prepared(true).Rows(records) // adds on conflict do nothing to insert query if *onConflictDoNothing { insert = insert.OnConflict(goqu.DoNothing()) @@ -103,50 +93,6 @@ func BuildInsertQuery( return query, args, nil } -// BuildPreparedQuery creates a prepared statement query template -func BuildPreparedInsertQuery( - driver, schema, table string, - columns []string, - rowCount int, - onConflictDoNothing bool, -) (string, error) { - if rowCount < 1 { - rowCount = 1 - } - - builder := getGoquDialect(driver) - sqltable := goqu.S(schema).Table(table) - - insertCols := make([]any, len(columns)) - for i, col := range columns { - insertCols[i] = col - } - - insert := builder.Insert(sqltable). - Prepared(true). - Cols(insertCols...) 
- - // Add placeholder rows based on rowCount - for i := 0; i < rowCount; i++ { - placeholderRow := make(goqu.Vals, len(columns)) - for j := range columns { - placeholderRow[j] = nil - } - insert = insert.Vals(placeholderRow) - } - - if onConflictDoNothing { - insert = insert.OnConflict(goqu.DoNothing()) - } - - query, _, err := insert.ToSQL() - if err != nil { - return "", err - } - - return query, nil -} - func BuildUpdateQuery( driver, schema, table string, insertColumns []string, diff --git a/worker/pkg/query-builder/query-builder_test.go b/worker/pkg/query-builder/query-builder_test.go index 00a001b17c..6f679a9273 100644 --- a/worker/pkg/query-builder/query-builder_test.go +++ b/worker/pkg/query-builder/query-builder_test.go @@ -4,11 +4,7 @@ import ( "fmt" "testing" - "github.com/doug-martin/goqu/v9" - "github.com/lib/pq" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" - "github.com/nucleuscloud/neosync/internal/testutil" - neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" "github.com/stretchr/testify/require" ) @@ -102,170 +98,74 @@ func Test_BuildUpdateQuery(t *testing.T) { func Test_BuildInsertQuery(t *testing.T) { tests := []struct { - name string - driver string - schema string - table string - columns []string - columnDataTypes []string - values [][]any - onConflictDoNothing bool - columnDefaultProperties []*neosync_benthos.ColumnDefaultProperties - expected string - expectedArgs []any + name string + driver string + schema string + table string + records []map[string]any + onConflictDoNothing bool + expected string + expectedArgs []any }{ - {"Single Column mysql", "mysql", "public", "users", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, []*neosync_benthos.ColumnDefaultProperties{}, "INSERT INTO `public`.`users` (`name`) VALUES (?), (?)", []any{"Alice", "Bob"}}, - {"Special characters mysql", "mysql", "public", "users.stage$dev", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, []*neosync_benthos.ColumnDefaultProperties{}, "INSERT INTO `public`.`users.stage$dev` (`name`) VALUES (?), (?)", []any{"Alice", "Bob"}}, - {"Multiple Columns mysql", "mysql", "public", "users", []string{"name", "email"}, []string{}, [][]any{{"Alice", "alice@fake.com"}, {"Bob", "bob@fake.com"}}, true, []*neosync_benthos.ColumnDefaultProperties{}, "INSERT IGNORE INTO `public`.`users` (`name`, `email`) VALUES (?, ?), (?, ?)", []any{"Alice", "alice@fake.com", "Bob", "bob@fake.com"}}, - {"Single Column postgres", "postgres", "public", "users", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, []*neosync_benthos.ColumnDefaultProperties{}, `INSERT INTO "public"."users" ("name") VALUES ($1), ($2)`, []any{"Alice", "Bob"}}, - {"Multiple Columns postgres", "postgres", "public", "users", []string{"name", "email"}, []string{}, [][]any{{"Alice", "alice@fake.com"}, {"Bob", "bob@fake.com"}}, true, []*neosync_benthos.ColumnDefaultProperties{}, `INSERT INTO "public"."users" ("name", "email") VALUES ($1, $2), ($3, $4) ON CONFLICT DO NOTHING`, []any{"Alice", "alice@fake.com", "Bob", "bob@fake.com"}}, + { + name: "Single Column mysql", + driver: "mysql", + schema: "public", + table: "users", + records: []map[string]any{{"name": "Alice"}, {"name": "Bob"}}, + onConflictDoNothing: false, + expected: "INSERT INTO `public`.`users` (`name`) VALUES (?), (?)", + expectedArgs: []any{"Alice", "Bob"}, + }, + { + name: "Special characters mysql", + driver: "mysql", + schema: "public", + table: "users.stage$dev", + records: 
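// goqu.Record is a map, so goqu orders the generated columns
// deterministically (alphabetically); that is why the multi-column
// expectations below list `email` before `name`.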
[]map[string]any{{"name": "Alice"}, {"name": "Bob"}}, + onConflictDoNothing: false, + expected: "INSERT INTO `public`.`users.stage$dev` (`name`) VALUES (?), (?)", + expectedArgs: []any{"Alice", "Bob"}, + }, + { + name: "Multiple Columns mysql", + driver: "mysql", + schema: "public", + table: "users", + records: []map[string]any{{"name": "Alice", "email": "alice@fake.com"}, {"name": "Bob", "email": "bob@fake.com"}}, + onConflictDoNothing: true, + expected: "INSERT IGNORE INTO `public`.`users` (`email`, `name`) VALUES (?, ?), (?, ?)", + expectedArgs: []any{"alice@fake.com", "Alice", "bob@fake.com", "Bob"}, + }, + { + name: "Single Column postgres", + driver: "postgres", + schema: "public", + table: "users", + records: []map[string]any{{"name": "Alice"}, {"name": "Bob"}}, + onConflictDoNothing: false, + expected: `INSERT INTO "public"."users" ("name") VALUES ($1), ($2)`, + expectedArgs: []any{"Alice", "Bob"}, + }, + { + name: "Multiple Columns postgres", + driver: "postgres", + schema: "public", + table: "users", + records: []map[string]any{{"name": "Alice", "email": "alice@fake.com"}, {"name": "Bob", "email": "bob@fake.com"}}, + onConflictDoNothing: true, + expected: `INSERT INTO "public"."users" ("email", "name") VALUES ($1, $2), ($3, $4) ON CONFLICT DO NOTHING`, + expectedArgs: []any{"alice@fake.com", "Alice", "bob@fake.com", "Bob"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - goquvals := toGoquVals(tt.values) - actual, args, err := BuildInsertQuery(tt.driver, tt.schema, tt.table, tt.columns, goquvals, &tt.onConflictDoNothing) + goquRows := toGoquRecords(tt.records) + actual, args, err := BuildInsertQuery(tt.driver, tt.schema, tt.table, goquRows, &tt.onConflictDoNothing) require.NoError(t, err) require.Equal(t, tt.expected, actual) require.Equal(t, tt.expectedArgs, args) }) } } - -func Test_BuildInsertQuery_JsonArray(t *testing.T) { - logger := testutil.GetTestLogger(t) - driver := sqlmanager_shared.PostgresDriver - schema := "public" - table := "test_table" - columns := []string{"id", "name", "tags"} - columnDataTypes := []string{"int", "text", "jsonb[]"} - columnDefaultProperties := []*neosync_benthos.ColumnDefaultProperties{nil, nil, nil} - values := [][]any{ - {1, "John", []map[string]any{{"tag": "cool"}, {"tag": "awesome"}}}, - {2, "Jane", []map[string]any{{"tag": "smart"}, {"tag": "clever"}}}, - } - onConflictDoNothing := false - goquvals := toGoquVals(getPostgresVals(logger, values, columnDataTypes, columnDefaultProperties)) - - query, _, err := BuildInsertQuery(driver, schema, table, columns, goquvals, &onConflictDoNothing) - require.NoError(t, err) - expectedQuery := `INSERT INTO "public"."test_table" ("id", "name", "tags") VALUES ($1, $2, ARRAY['{"tag":"cool"}','{"tag":"awesome"}']::jsonb[]), ($3, $4, ARRAY['{"tag":"smart"}','{"tag":"clever"}']::jsonb[])` - require.Equal(t, expectedQuery, query) -} - -func Test_BuildInsertQuery_Json(t *testing.T) { - logger := testutil.GetTestLogger(t) - driver := sqlmanager_shared.PostgresDriver - schema := "public" - table := "test_table" - columns := []string{"id", "name", "tags"} - columnDataTypes := []string{"int", "text", "json"} - columnDefaultProperties := []*neosync_benthos.ColumnDefaultProperties{} - values := [][]any{ - {1, "John", map[string]any{"tag": "cool"}}, - {2, "Jane", map[string]any{"tag": "smart"}}, - } - onConflictDoNothing := false - - goquvals := toGoquVals(getPostgresVals(logger, values, columnDataTypes, columnDefaultProperties)) - query, args, err := BuildInsertQuery(driver, schema, table, 
columns, goquvals, &onConflictDoNothing) - require.NoError(t, err) - expectedQuery := `INSERT INTO "public"."test_table" ("id", "name", "tags") VALUES ($1, $2, $3), ($4, $5, $6)` - require.Equal(t, expectedQuery, query) - require.Equal(t, []any{int64(1), "John", []byte{123, 34, 116, 97, 103, 34, 58, 34, 99, 111, 111, 108, 34, 125}, int64(2), "Jane", []byte{123, 34, 116, 97, 103, 34, 58, 34, 115, 109, 97, 114, 116, 34, 125}}, args) -} - -func TestGetGoquVals(t *testing.T) { - t.Run("Postgres", func(t *testing.T) { - logger := testutil.GetTestLogger(t) - rows := [][]any{{"value1", 42, true, map[string]any{"key": "value"}, []int{1, 2, 3}}} - columnDataTypes := []string{"text", "integer", "boolean", "jsonb", "integer[]"} - columnDefaultProperties := []*neosync_benthos.ColumnDefaultProperties{nil, nil, nil, nil, nil} - - result := getPostgresVals(logger, rows, columnDataTypes, columnDefaultProperties) - - require.Len(t, result, 1) - row := result[0] - require.Equal(t, "value1", row[0]) - require.Equal(t, 42, row[1]) - require.Equal(t, true, row[2]) - require.JSONEq(t, `{"key":"value"}`, string(row[3].([]byte))) - require.Equal(t, pq.Array([]int{1, 2, 3}), row[4]) - }) - - t.Run("Postgres JSON", func(t *testing.T) { - logger := testutil.GetTestLogger(t) - rows := [][]any{{"value1", 42, true, map[string]any{"key": "value"}, []int{1, 2, 3}}} - columnDataTypes := []string{"jsonb", "jsonb", "jsonb", "jsonb", "json"} - columnDefaultProperties := []*neosync_benthos.ColumnDefaultProperties{nil, nil, nil, nil, nil} - - result := getPostgresVals(logger, rows, columnDataTypes, columnDefaultProperties) - - require.Len(t, result, 1) - require.Equal(t, []any{ - []byte(`"value1"`), - []byte(`42`), - []byte(`true`), - []byte(`{"key":"value"}`), - []byte(`[1,2,3]`), - }, result[0]) - }) - - t.Run("Postgres Empty Column DataTypes", func(t *testing.T) { - logger := testutil.GetTestLogger(t) - rows := [][]any{{"value1", 42, true, "DEFAULT"}} - columnDataTypes := []string{} - columnDefaultProperties := []*neosync_benthos.ColumnDefaultProperties{nil, nil, nil, {HasDefaultTransformer: true}} - - result := getPostgresVals(logger, rows, columnDataTypes, columnDefaultProperties) - - require.Len(t, result, 1) - row := result[0] - require.Equal(t, "value1", row[0]) - require.Equal(t, 42, row[1]) - require.Equal(t, true, row[2]) - require.Equal(t, goqu.L("DEFAULT"), row[3]) - }) - - t.Run("Mysql", func(t *testing.T) { - logger := testutil.GetTestLogger(t) - rows := [][]any{{"value1", 42, true, "DEFAULT"}} - columnDataTypes := []string{} - columnDefaultProperties := []*neosync_benthos.ColumnDefaultProperties{nil, nil, nil, {HasDefaultTransformer: true}} - - result := getMysqlVals(logger, rows, columnDataTypes, columnDefaultProperties) - - require.Len(t, result, 1) - row := result[0] - require.Equal(t, "value1", row[0]) - require.Equal(t, 42, row[1]) - require.Equal(t, true, row[2]) - require.Equal(t, goqu.L("DEFAULT"), row[3]) - }) - - t.Run("EmptyRow", func(t *testing.T) { - logger := testutil.GetTestLogger(t) - rows := [][]any{} - columnDataTypes := []string{} - columnDefaultProperties := []*neosync_benthos.ColumnDefaultProperties{} - - result := getMysqlVals(logger, rows, columnDataTypes, columnDefaultProperties) - - require.Empty(t, result) - }) - - t.Run("Mismatch length ColumnDataTypes and Row Values", func(t *testing.T) { - logger := testutil.GetTestLogger(t) - rows := [][]any{{"text", 42, true}} - columnDataTypes := []string{"text"} - columnDefaultProperties := []*neosync_benthos.ColumnDefaultProperties{} - - result 
:= getMysqlVals(logger, rows, columnDataTypes, columnDefaultProperties) - - require.Len(t, result, 1) - row := result[0] - require.Equal(t, "text", row[0]) - require.Equal(t, 42, row[1]) - require.Equal(t, true, row[2]) - }) -} diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/create-table.sql b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/create-table.sql index a1c3ccfa5d..15e677f45e 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/create-table.sql +++ b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/create-table.sql @@ -30,6 +30,7 @@ CREATE TABLE alltypes.alldatatypes ( -- Unicode character strings col_nchar NCHAR(10), col_nvarchar NVARCHAR(50), + col_json NVARCHAR(MAX), col_ntext NTEXT, -- Binary strings diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/insert.sql b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/insert.sql index a4ecdee50f..5760036acc 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/insert.sql +++ b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/insert.sql @@ -10,7 +10,7 @@ INSERT INTO alltypes.alldatatypes ( -- Character strings col_char, col_varchar, col_text, -- Unicode character strings - col_nchar, col_nvarchar, col_ntext, + col_nchar, col_nvarchar, col_json, col_ntext, -- -- Binary strings -- col_binary, col_varbinary, col_image, -- Other data types @@ -52,6 +52,7 @@ VALUES ( -- Unicode character strings N'NCHAR ', -- NCHAR(10) N'NVARCHAR', -- NVARCHAR(50) + N'{"key": "value"}', -- JSON N'This is an NTEXT column', -- NTEXT -- -- Binary strings diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/job_mappings.go b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/job_mappings.go index 468275c96c..9f3a4f60df 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/job_mappings.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/mssql/data-types/job_mappings.go @@ -186,6 +186,14 @@ func GetDefaultSyncJobMappings()[]*mgmtv1alpha1.JobMapping { Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, }, }, + { + Schema: "alltypes", + Table: "alldatatypes", + Column: "col_json", + Transformer: &mgmtv1alpha1.JobMappingTransformer{ + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, + }, + }, { Schema: "alltypes", Table: "alldatatypes", @@ -240,6 +248,7 @@ func GetTableColumnTypeMap() map[string]map[string]string { "col_text": "TEXT", "col_nchar": "NCHAR(10)", "col_nvarchar": "NVARCHAR(50)", + "col_json": "NVARCHAR(MAX)", "col_ntext": "NTEXT", "col_uniqueidentifier": "UNIQUEIDENTIFIER", "col_xml": "XML", diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mssql/simple/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mssql/simple/tests.go index 5767a7c079..2ef3472c36 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mssql/simple/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/mssql/simple/tests.go @@ -109,7 +109,6 @@ func getJobmappings() []*mgmtv1alpha1.JobMapping { func getDefaultTransformerConfig() *mgmtv1alpha1.JobMappingTransformer { return &mgmtv1alpha1.JobMappingTransformer{ - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_DEFAULT, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateDefaultConfig{ GenerateDefaultConfig: &mgmtv1alpha1.GenerateDefault{}, diff --git 
a/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go index c5709d0ace..289816bc20 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go @@ -52,7 +52,6 @@ func getJobmappings() []*mgmtv1alpha1.JobMapping { Table: jm.Table, Column: jm.Column, Transformer: &mgmtv1alpha1.JobMappingTransformer{ - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_DEFAULT, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateDefaultConfig{}, }, diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/job_mappings.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/job_mappings.go index 9635822a96..f244196fd8 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/job_mappings.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/job_mappings.go @@ -682,6 +682,46 @@ func GetDefaultSyncJobMappings()[]*mgmtv1alpha1.JobMapping { Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, }, }, + { + Schema: "alltypes", + Table: "products", + Column: "id", + Transformer: &mgmtv1alpha1.JobMappingTransformer{ + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, + }, + }, + { + Schema: "alltypes", + Table: "products", + Column: "price", + Transformer: &mgmtv1alpha1.JobMappingTransformer{ + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, + }, + }, + { + Schema: "alltypes", + Table: "products", + Column: "tax_rate", + Transformer: &mgmtv1alpha1.JobMappingTransformer{ + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, + }, + }, + { + Schema: "alltypes", + Table: "products", + Column: "tax_amount", + Transformer: &mgmtv1alpha1.JobMappingTransformer{ + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, + }, + }, + { + Schema: "alltypes", + Table: "products", + Column: "total_price", + Transformer: &mgmtv1alpha1.JobMappingTransformer{ + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, + }, + }, { Schema: "CaPiTaL", Table: "BadName", diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/setup.sql b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/setup.sql index 8e699574a2..44c01ca67b 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/setup.sql +++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/setup.sql @@ -347,6 +347,18 @@ INSERT INTO alltypes.json_data (data) VALUES ( }' ); +CREATE TABLE alltypes.products ( + id SERIAL PRIMARY KEY, + price DECIMAL(10,2), + tax_rate DECIMAL(4,2), + tax_amount DECIMAL(10,2) GENERATED ALWAYS AS (price * tax_rate / 100) STORED, + total_price DECIMAL(10,2) GENERATED ALWAYS AS (price + (price * tax_rate / 100)) STORED +); + +INSERT INTO alltypes.products (price, tax_rate) VALUES + (100.00, 10.00), + (50.00, 8.50); + CREATE SCHEMA IF NOT EXISTS "CaPiTaL"; CREATE TABLE IF NOT EXISTS "CaPiTaL"."BadName" ( "ID" BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go index 88f590328a..d2b5438af2 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go 
+++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go @@ -1,6 +1,9 @@ package postgres_alltypes -import workflow_testdata "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/workflow/testdata" +import ( + mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" + workflow_testdata "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/workflow/testdata" +) func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ @@ -9,7 +12,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { Folder: "testdata/postgres/all-types", SourceFilePaths: []string{"setup.sql"}, TargetFilePaths: []string{"schema-create.sql", "setup.sql"}, - JobMappings: GetDefaultSyncJobMappings(), + JobMappings: getJobmappings(), JobOptions: &workflow_testdata.TestJobOptions{ Truncate: true, TruncateCascade: true, @@ -26,7 +29,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { Folder: "testdata/postgres/all-types", SourceFilePaths: []string{"setup.sql"}, TargetFilePaths: []string{"schema-create.sql"}, - JobMappings: GetDefaultSyncJobMappings(), + JobMappings: getJobmappings(), JobOptions: &workflow_testdata.TestJobOptions{ InitSchema: true, }, @@ -39,3 +42,25 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { }, } } + +func getJobmappings() []*mgmtv1alpha1.JobMapping { + jobmappings := GetDefaultSyncJobMappings() + updatedJobmappings := []*mgmtv1alpha1.JobMapping{} + for _, jm := range jobmappings { + if jm.Column == "tax_amount" || jm.Column == "total_price" { + updatedJobmappings = append(updatedJobmappings, &mgmtv1alpha1.JobMapping{ + Schema: jm.Schema, + Table: jm.Table, + Column: jm.Column, + Transformer: &mgmtv1alpha1.JobMappingTransformer{ + Config: &mgmtv1alpha1.TransformerConfig{ + Config: &mgmtv1alpha1.TransformerConfig_GenerateDefaultConfig{}, + }, + }, + }) + } else { + updatedJobmappings = append(updatedJobmappings, jm) + } + } + return updatedJobmappings +}
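
Note on the query-builder refactor: the updated tests call BuildInsertQuery(tt.driver, tt.schema, tt.table, goquRows, &tt.onConflictDoNothing), i.e. the builder now takes goqu records instead of parallel column/value slices, and goqu renders record keys in sorted order, which is why the new expectations list `email` before `name` and reorder the bound args to match. The sketch below is a hedged reconstruction under that assumption; toGoquRecords is not shown in this diff, so its exact shape is a guess:

package main

import (
	"fmt"

	"github.com/doug-martin/goqu/v9"
	_ "github.com/doug-martin/goqu/v9/dialect/mysql"
)

// toGoquRecords converts generic row maps into goqu.Record values. This is
// an illustrative guess at the helper referenced by the updated tests; the
// real implementation may differ.
func toGoquRecords(records []map[string]any) []goqu.Record {
	out := make([]goqu.Record, 0, len(records))
	for _, r := range records {
		out = append(out, goqu.Record(r))
	}
	return out
}

func main() {
	recs := toGoquRecords([]map[string]any{
		{"name": "Alice", "email": "alice@fake.com"},
		{"name": "Bob", "email": "bob@fake.com"},
	})
	rows := make([]any, len(recs))
	for i, r := range recs {
		rows[i] = r
	}
	// goqu sorts record keys, so the column list comes out as
	// (`email`, `name`) and the args follow that order.
	sql, args, err := goqu.Dialect("mysql").
		Insert(goqu.S("public").Table("users")).
		Prepared(true).
		Rows(rows...).
		ToSQL()
	if err != nil {
		panic(err)
	}
	fmt.Println(sql, args)
	// INSERT INTO `public`.`users` (`email`, `name`) VALUES (?, ?), (?, ?)
	// [alice@fake.com Alice bob@fake.com Bob]
}

The conflict handling seen in the expectations (INSERT IGNORE for mysql, ON CONFLICT DO NOTHING for postgres) is goqu's standard per-dialect rendering of a DoNothing conflict clause.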
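
Note on the MSSQL fixtures: SQL Server has no dedicated JSON column type, so the new col_json column is declared NVARCHAR(MAX) and the payload is synced as an ordinary Unicode string via a passthrough transformer. A minimal sketch of the kind of assertion a verification step could make on the round-tripped value (this check is illustrative, not part of the PR):

package main

import (
	"encoding/json"
	"log"
)

func main() {
	// The value inserted by insert.sql; NVARCHAR(MAX) imposes no
	// JSON-specific handling, so it should arrive byte-for-byte intact.
	raw := `{"key": "value"}`
	if !json.Valid([]byte(raw)) {
		log.Fatal("col_json payload is not valid JSON")
	}
	log.Println("col_json payload round-trips as valid JSON")
}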
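
Note on the Postgres fixtures: the new products table exercises GENERATED ALWAYS AS ... STORED columns, and Postgres rejects any explicit value for such a column other than the keyword DEFAULT. That is why getJobmappings above swaps tax_amount and total_price over to the generate-default transformer, and it matches what the removed getPostgresVals/getMysqlVals tests asserted: a default-transformed value becomes the raw literal goqu.L("DEFAULT") rather than a bound parameter. A minimal sketch of that translation, assuming a boolean flag equivalent to ColumnDefaultProperties.HasDefaultTransformer from the deleted tests:

package main

import (
	"fmt"

	"github.com/doug-martin/goqu/v9"
)

// toInsertValue mirrors the behavior the removed tests asserted: when a
// column uses the generate-default transformer, emit the raw SQL keyword
// DEFAULT via goqu.L so the database computes the value itself. The flag
// name is an assumption taken from the deleted test fixtures.
func toInsertValue(val any, hasDefaultTransformer bool) any {
	if hasDefaultTransformer {
		return goqu.L("DEFAULT")
	}
	return val
}

func main() {
	fmt.Println(toInsertValue(100.00, false)) // prints 100 (bound as a parameter)
	fmt.Println(toInsertValue(11.00, true))   // prints the goqu literal expression for DEFAULT
}

With that in place, an insert that still names the generated columns renders as VALUES ($1, $2, DEFAULT, DEFAULT), which Postgres accepts and fills from the generation expressions.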