Skip to content

Commit

Permalink
Fix mysql date and datetime in CLI sync (#2893)
Browse files Browse the repository at this point in the history
  • Loading branch information
alishakawaguchi authored Nov 1, 2024
1 parent 1745f83 commit 6242842
Show file tree
Hide file tree
Showing 20 changed files with 486 additions and 154 deletions.
8 changes: 5 additions & 3 deletions backend/pkg/dbconnect-config/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ func NewFromMysqlConnection(
config *mgmtv1alpha1.ConnectionConfig_MysqlConfig,
connectionTimeout *uint32,
logger *slog.Logger,
mysqlDisableParseTime bool,
) (DbConnectConfig, error) {
parseTime := !mysqlDisableParseTime
switch cc := config.MysqlConfig.GetConnectionConfig().(type) {
case *mgmtv1alpha1.MysqlConnectionConfig_Connection:
cfg := mysql.NewConfig()
Expand All @@ -46,7 +48,7 @@ func NewFromMysqlConnection(
}
cfg.Net = cc.Connection.GetProtocol()
cfg.MultiStatements = true
cfg.ParseTime = true
cfg.ParseTime = parseTime

return &mysqlConnectConfig{dsn: cfg.FormatDSN(), user: cfg.User}, nil
case *mgmtv1alpha1.MysqlConnectionConfig_Url:
Expand Down Expand Up @@ -76,7 +78,7 @@ func NewFromMysqlConnection(
cfg.Timeout = time.Duration(*connectionTimeout) * time.Second
}
cfg.MultiStatements = true
cfg.ParseTime = true
cfg.ParseTime = parseTime
for k, values := range uriConfig.Query() {
for _, value := range values {
cfg.Params[k] = value
Expand All @@ -89,7 +91,7 @@ func NewFromMysqlConnection(
cfg.Timeout = time.Duration(*connectionTimeout) * time.Second
}
cfg.MultiStatements = true
cfg.ParseTime = true
cfg.ParseTime = parseTime
return &mysqlConnectConfig{dsn: cfg.FormatDSN(), user: cfg.User}, nil
default:
return nil, fmt.Errorf("unsupported mysql connection config: %T", cc)
Expand Down
30 changes: 30 additions & 0 deletions backend/pkg/dbconnect-config/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func Test_NewFromMysqlConnection(t *testing.T) {
},
&testConnectionTimeout,
discardLogger,
false,
)
assert.NoError(t, err)
assert.NotNil(t, actual)
Expand All @@ -45,6 +46,29 @@ func Test_NewFromMysqlConnection(t *testing.T) {
)
assert.Equal(t, "test-user", actual.GetUser())
})

t.Run("ok_disable_parse_time", func(t *testing.T) {
actual, err := NewFromMysqlConnection(
&mgmtv1alpha1.ConnectionConfig_MysqlConfig{
MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{
Connection: mysqlconnectionFixture,
},
},
},
&testConnectionTimeout,
discardLogger,
true,
)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(
t,
"test-user:test-pass@tcp(localhost:3309)/mydb?multiStatements=true&timeout=5s",
actual.String(),
)
assert.Equal(t, "test-user", actual.GetUser())
})
t.Run("ok_no_timeout", func(t *testing.T) {
actual, err := NewFromMysqlConnection(
&mgmtv1alpha1.ConnectionConfig_MysqlConfig{
Expand All @@ -56,6 +80,7 @@ func Test_NewFromMysqlConnection(t *testing.T) {
},
nil,
discardLogger,
false,
)
assert.NoError(t, err)
assert.NotNil(t, actual)
Expand All @@ -80,6 +105,7 @@ func Test_NewFromMysqlConnection(t *testing.T) {
},
&testConnectionTimeout,
discardLogger,
false,
)
assert.NoError(t, err)
assert.NotNil(t, actual)
Expand All @@ -101,6 +127,7 @@ func Test_NewFromMysqlConnection(t *testing.T) {
},
nil,
discardLogger,
false,
)
assert.NoError(t, err)
assert.NotNil(t, actual)
Expand All @@ -122,6 +149,7 @@ func Test_NewFromMysqlConnection(t *testing.T) {
},
nil,
discardLogger,
false,
)
assert.NoError(t, err)
assert.NotNil(t, actual)
Expand All @@ -146,6 +174,7 @@ func Test_NewFromMysqlConnection(t *testing.T) {
},
&testConnectionTimeout,
discardLogger,
false,
)
assert.NoError(t, err)
assert.NotNil(t, actual)
Expand All @@ -167,6 +196,7 @@ func Test_NewFromMysqlConnection(t *testing.T) {
},
nil,
discardLogger,
false,
)
assert.NoError(t, err)
assert.NotNil(t, actual)
Expand Down
43 changes: 29 additions & 14 deletions backend/pkg/sqlconnect/mock_SqlConnector.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 32 additions & 4 deletions backend/pkg/sqlconnect/sql-connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,45 @@ type SqlDBTX interface {
BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
}

type SqlConnectorOption func(*sqlConnectorOptions)

type sqlConnectorOptions struct {
mysqlDisableParseTime bool
postgresDriver string
}

// WithMysqlParseTimeDisabled disables MySQL time parsing
func WithMysqlParseTimeDisabled() SqlConnectorOption {
return func(opts *sqlConnectorOptions) {
opts.mysqlDisableParseTime = true
}
}

// WithPostgresDriver overrides default postgres driver
func WithDefaultPostgresDriver() SqlConnectorOption {
return func(opts *sqlConnectorOptions) {
opts.postgresDriver = "postgres"
}
}

type SqlConnector interface {
NewDbFromConnectionConfig(connectionConfig *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (SqlDbContainer, error)
NewDbFromConnectionConfig(connectionConfig *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger, opts ...SqlConnectorOption) (SqlDbContainer, error)
}

type SqlOpenConnector struct{}

func (rc *SqlOpenConnector) NewDbFromConnectionConfig(cc *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (SqlDbContainer, error) {
func (rc *SqlOpenConnector) NewDbFromConnectionConfig(cc *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger, opts ...SqlConnectorOption) (SqlDbContainer, error) {
if cc == nil {
return nil, errors.New("connectionConfig was nil, expected *mgmtv1alpha1.ConnectionConfig")
}

options := sqlConnectorOptions{
postgresDriver: "pgx",
}
for _, opt := range opts {
opt(&options)
}

dbconnopts, err := getConnectionOptsFromConnectionConfig(cc)
if err != nil {
return nil, err
Expand Down Expand Up @@ -76,10 +104,10 @@ func (rc *SqlOpenConnector) NewDbFromConnectionConfig(cc *mgmtv1alpha1.Connectio
dbconnopts,
), nil
} else {
return newStdlibContainer("pgx", dsn, dbconnopts), nil
return newStdlibContainer(options.postgresDriver, dsn, dbconnopts), nil
}
case *mgmtv1alpha1.ConnectionConfig_MysqlConfig:
connDetails, err := dbconnectconfig.NewFromMysqlConnection(config, connectionTimeout, logger)
connDetails, err := dbconnectconfig.NewFromMysqlConnection(config, connectionTimeout, logger, options.mysqlDisableParseTime)
if err != nil {
return nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors"
neosync_gcp "github.com/nucleuscloud/neosync/backend/internal/gcp"
"github.com/nucleuscloud/neosync/backend/internal/neosyncdb"
"github.com/nucleuscloud/neosync/backend/pkg/sqlconnect"
sqlmanager_mysql "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/mysql"
sqlmanager_postgres "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/postgres"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
Expand Down Expand Up @@ -87,7 +88,7 @@ func (s *Service) GetConnectionDataStream(
return err
}

conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.ConnectionConfig, &connectionTimeout, logger)
conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.ConnectionConfig, &connectionTimeout, logger, sqlconnect.WithMysqlParseTimeDisabled())
if err != nil {
return err
}
Expand Down Expand Up @@ -148,7 +149,7 @@ func (s *Service) GetConnectionDataStream(
return err
}

conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.GetConnectionConfig(), &connectionTimeout, logger)
conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.GetConnectionConfig(), &connectionTimeout, logger, sqlconnect.WithDefaultPostgresDriver())
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func getDbRoleFromConnectionConfig(cconfig *mgmtv1alpha1.ConnectionConfig, logge
}
return parsedCfg.GetUser(), nil
case *mgmtv1alpha1.ConnectionConfig_MysqlConfig:
parsedCfg, err := dbconnectconfig.NewFromMysqlConnection(typedconfig, nil, logger)
parsedCfg, err := dbconnectconfig.NewFromMysqlConnection(typedconfig, nil, logger, false)
if err != nil {
return "", fmt.Errorf("unable to parse mysql connection: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cli/internal/cmds/neosync/sync/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
)

func createJob(
func toJob(
cmd *cmdConfig,
sourceConnection *mgmtv1alpha1.Connection,
destinationConnection *mgmtv1alpha1.Connection,
Expand Down
6 changes: 5 additions & 1 deletion cli/internal/cmds/neosync/sync/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ func (c *clisync) configureAndRunSync() error {
case <-ctx.Done():
return
case <-stopChan:
c.logger.Error("Sync Failed.")
cancel()
os.Exit(1)
return
}
}
Expand Down Expand Up @@ -369,6 +371,8 @@ func (c *clisync) configureSync() ([][]*benthosbuilder.BenthosConfigResponse, er
if syncConfigs == nil {
return nil, nil
}

// TODO move this after benthos builder
c.logger.Info("Running table init statements...")
err = c.runDestinationInitStatements(syncConfigs, schemaConfig)
if err != nil {
Expand All @@ -378,7 +382,7 @@ func (c *clisync) configureSync() ([][]*benthosbuilder.BenthosConfigResponse, er
syncConfigCount := len(syncConfigs)
c.logger.Info(fmt.Sprintf("Generating %d sync configs...", syncConfigCount))

job, err := createJob(c.cmd, c.sourceConnection, c.destinationConnection, schemaConfig.Schemas)
job, err := toJob(c.cmd, c.sourceConnection, c.destinationConnection, schemaConfig.Schemas)
if err != nil {
c.logger.Error("unable to create job")
return nil, err
Expand Down
17 changes: 11 additions & 6 deletions cli/internal/cmds/neosync/sync/sync_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ func Test_Sync(t *testing.T) {
panic(err)
}

testdataFolder := "../../../../../internal/testutil/testdata/postgres/humanresources"
err = postgres.Source.RunSqlFiles(ctx, &testdataFolder, []string{"create-tables.sql"})
testdataFolder := "../../../../../internal/testutil/testdata/postgres"
err = postgres.Source.RunSqlFiles(ctx, &testdataFolder, []string{"humanresources/create-tables.sql"})
if err != nil {
panic(err)
}
err = postgres.Target.RunSqlFiles(ctx, &testdataFolder, []string{"create-schema.sql"})
err = postgres.Target.RunSqlFiles(ctx, &testdataFolder, []string{"humanresources/create-schema.sql"})
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -106,12 +106,12 @@ func Test_Sync(t *testing.T) {
panic(err)
}

testdataFolder := "../../../../../internal/testutil/testdata/mysql/humanresources"
err = mysql.Source.RunSqlFiles(ctx, &testdataFolder, []string{"create-tables.sql"})
testdataFolder := "../../../../../internal/testutil/testdata/mysql"
err = mysql.Source.RunSqlFiles(ctx, &testdataFolder, []string{"humanresources/create-tables.sql", "alltypes/create-tables.sql"})
if err != nil {
panic(err)
}
err = mysql.Target.RunSqlFiles(ctx, &testdataFolder, []string{"create-schema.sql"})
err = mysql.Target.RunSqlFiles(ctx, &testdataFolder, []string{"humanresources/create-schema.sql", "alltypes/create-schema.sql"})
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -152,6 +152,11 @@ func Test_Sync(t *testing.T) {
err = rows.Scan(&rowCount)
require.NoError(t, err)
require.Greater(t, rowCount, 1)

rows = mysql.Target.DB.QueryRowContext(ctx, "select count(*) from alltypes.all_data_types;")
err = rows.Scan(&rowCount)
require.NoError(t, err)
require.Greater(t, rowCount, 1)
})

t.Cleanup(func() {
Expand Down
Loading

0 comments on commit 6242842

Please sign in to comment.