Skip to content

Commit

Permalink
[sqlserver] Always access version table with explicit schema
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-kuck committed Oct 26, 2022
1 parent ce58221 commit c89e846
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions database/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}

query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
query := `TRUNCATE TABLE "` + ss.getMigrationTable() + `"`
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
Expand All @@ -279,7 +279,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error {
if dirty {
dirtyBit = 1
}
query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
query = `INSERT INTO "` + ss.getMigrationTable() + `" (version, dirty) VALUES (@p1, @p2)`
if _, err := tx.Exec(query, version, dirtyBit); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
Expand All @@ -297,7 +297,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error {

// Version of the current database state
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
query := `SELECT TOP 1 version, dirty FROM "` + ss.getMigrationTable() + `"`
err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
Expand Down Expand Up @@ -365,10 +365,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) {
query := `IF NOT EXISTS
(SELECT *
FROM sysobjects
WHERE id = object_id(N'[` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `]')
WHERE id = object_id(N'` + ss.getMigrationTable() + `')
AND OBJECTPROPERTY(id, N'IsUserTable') = 1
)
CREATE TABLE [` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `] ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
CREATE TABLE ` + ss.getMigrationTable() + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`

if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
Expand All @@ -377,6 +377,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) {
return nil
}

func (ss *SQLServer) getMigrationTable() string {
return fmt.Sprintf("[%s].[%s]", ss.config.SchemaName, ss.config.MigrationsTable)
}

func getMSITokenProvider(resource string) (func() (string, error), error) {
msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
if err != nil {
Expand Down

0 comments on commit c89e846

Please sign in to comment.