Skip to content

Commit

Permalink
feat: Implemented Inserts with ON CONFLICT DO NOTHING and more
Browse files Browse the repository at this point in the history
* Fixed error handling in the restore command.
* Now, restore jobs initiate a transaction for each table restoration.
* Implemented insert commands based on COPY format data.
* Added support for ON CONFLICT DO NOTHING.
* Implemented error exclusions for insert-like commands
* Implemented full --exit-on-error support for data section
* Fixed an issue where an error in a subsequent restoration task could roll back data inserted during a previous task
  • Loading branch information
wwoytenko committed Aug 13, 2024
1 parent 9580584 commit 03d7d7a
Show file tree
Hide file tree
Showing 14 changed files with 673 additions and 165 deletions.
3 changes: 1 addition & 2 deletions cmd/greenmask/cmd/dump/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ func init() {
Cmd.Flags().BoolP("no-tablespaces", "", false, "do not dump tablespace assignments")
Cmd.Flags().BoolP("no-toast-compression", "", false, "do not dump TOAST compression methods")
Cmd.Flags().BoolP("no-unlogged-table-data", "", false, "do not dump unlogged table data")
Cmd.Flags().BoolP("on-conflict-do-nothing", "", false, "add ON CONFLICT DO NOTHING to INSERT commands")
Cmd.Flags().BoolP("quote-all-identifiers", "", false, "quote all identifiers, even if not key words")
Cmd.Flags().StringP("section", "", "", "dump named section (pre-data, data, or post-data)")
Cmd.Flags().BoolP("serializable-deferrable", "", false, "wait until the dump can run without anomalies")
Expand Down Expand Up @@ -164,7 +163,7 @@ func init() {
"disable-triggers", "enable-row-security", "exclude-table-data", "extra-float-digits", "if-exists",
"include-foreign-data", "load-via-partition-root", "no-comments", "no-publications", "no-security-labels",
"no-subscriptions", "no-synchronized-snapshots", "no-tablespaces", "no-toast-compression",
"no-unlogged-table-data", "on-conflict-do-nothing", "quote-all-identifiers", "section",
"no-unlogged-table-data", "quote-all-identifiers", "section",
"serializable-deferrable", "snapshot", "strict-names", "use-set-session-authorization",

"dbname", "host", "port", "username",
Expand Down
8 changes: 5 additions & 3 deletions cmd/greenmask/cmd/restore/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ var (
st = st.SubStorage(dumpId, true)

restore := cmdInternals.NewRestore(
Config.Common.PgBinPath, st, &Config.Restore.PgRestoreOptions, Config.Restore.Scripts,
Config.Common.PgBinPath, st, &Config.Restore, Config.Restore.Scripts,
Config.Common.TempDirectory,
)

Expand Down Expand Up @@ -144,7 +144,7 @@ func init() {
Cmd.Flags().StringSliceVarP(&Config.Restore.PgRestoreOptions.ExcludeSchema, "exclude-schema", "N", []string{}, "do not restore objects in this schema")
Cmd.Flags().StringP("no-owner", "O", "", "skip restoration of object ownership")
Cmd.Flags().StringSliceVarP(&Config.Restore.PgRestoreOptions.Function, "function", "P", []string{}, "restore named function")
Cmd.Flags().StringP("schema-only", "s", "", "restore only the schema, no data")
Cmd.Flags().BoolP("schema-only", "s", false, "restore only the schema, no data")
Cmd.Flags().StringP("superuser", "S", "", "superuser user name to use for disabling triggers")
Cmd.Flags().StringSliceVarP(&Config.Restore.PgRestoreOptions.Table, "table", "t", []string{}, "restore named relation (table, view, etc.)")
Cmd.Flags().StringSliceVarP(&Config.Restore.PgRestoreOptions.Trigger, "trigger", "T", []string{}, "restore named trigger")
Expand All @@ -163,6 +163,8 @@ func init() {
Cmd.Flags().StringP("section", "", "", "restore named section (pre-data, data, or post-data)")
Cmd.Flags().BoolP("strict-names", "", false, "restore named section (pre-data, data, or post-data) match at least one entity each")
Cmd.Flags().BoolP("use-set-session-authorization", "", false, "use SET SESSION AUTHORIZATION commands instead of ALTER OWNER commands to set ownership")
Cmd.Flags().BoolP("on-conflict-do-nothing", "", false, "add ON CONFLICT DO NOTHING to INSERT commands")
Cmd.Flags().BoolP("inserts", "", false, "restore data as INSERT commands, rather than COPY")

// Connection options:
Cmd.Flags().StringP("host", "h", "/var/run/postgres", "database server host or socket directory")
Expand All @@ -176,7 +178,7 @@ func init() {
"no-owner", "function", "schema-only", "superuser", "table", "trigger", "no-privileges", "single-transaction",
"disable-triggers", "enable-row-security", "if-exists", "no-comments", "no-data-for-failed-tables",
"no-security-labels", "no-subscriptions", "no-table-access-method", "no-tablespaces", "section",
"strict-names", "use-set-session-authorization",
"strict-names", "use-set-session-authorization", "inserts", "on-conflict-do-nothing",

"host", "port", "username",
} {
Expand Down
123 changes: 64 additions & 59 deletions internal/db/postgres/cmd/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"strconv"
"time"

"github.com/greenmaskio/greenmask/internal/domains"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -75,22 +76,24 @@ type Restore struct {
dumpIdList []int32
tocObj *toc.Toc
tmpDir string
cfg *domains.Restore

preDataClenUpToc string
postDataClenUpToc string
}

func NewRestore(
binPath string, st storages.Storager, opt *pgrestore.Options, s map[string][]pgrestore.Script, tmpDir string,
binPath string, st storages.Storager, cfg *domains.Restore, s map[string][]pgrestore.Script, tmpDir string,
) *Restore {

return &Restore{
binPath: binPath,
st: st,
pgRestore: pgrestore.NewPgRestore(binPath),
restoreOpt: opt,
restoreOpt: &cfg.PgRestoreOptions,
scripts: s,
tmpDir: path.Join(tmpDir, fmt.Sprintf("%d", time.Now().UnixNano())),
cfg: cfg,
}
}

Expand Down Expand Up @@ -359,12 +362,13 @@ func (r *Restore) dataRestore(ctx context.Context) error {
}
defer conn.Close(ctx)

if err := r.RunScripts(ctx, conn, scriptDataSection, scriptExecuteBefore); err != nil {
if err = r.RunScripts(ctx, conn, scriptDataSection, scriptExecuteBefore); err != nil {
return err
}

tasks := make(chan restorers.RestoreTask, r.restoreOpt.Jobs)
eg, gtx := errgroup.WithContext(ctx)

for j := 0; j < r.restoreOpt.Jobs; j++ {
eg.Go(func(id int) func() error {
return func() error {
Expand All @@ -373,43 +377,7 @@ func (r *Restore) dataRestore(ctx context.Context) error {
}(j))
}

eg.Go(func() error {
defer close(tasks)
for _, entry := range r.tocObj.Entries {
select {
case <-gtx.Done():
return gtx.Err()
default:
}

if entry.Section == toc.SectionData {

if !r.isNeedRestore(entry) {
continue
}

var task restorers.RestoreTask
switch *entry.Desc {
case toc.TableDataDesc:
task = restorers.NewTableRestorer(entry, r.st)
case toc.SequenceSetDesc:
task = restorers.NewSequenceRestorer(entry)
case toc.BlobsDesc:
task = restorers.NewBlobsRestorer(entry, r.st)
}

if task != nil {
select {
case <-gtx.Done():
return gtx.Err()
case tasks <- task:
}
}
}

}
return nil
})
eg.Go(r.taskPusher(gtx, tasks))

if err := eg.Wait(); err != nil {
return fmt.Errorf("at least one worker exited with error: %w", err)
Expand Down Expand Up @@ -466,7 +434,7 @@ func (r *Restore) postDataRestore(ctx context.Context) error {
}
defer conn.Close(ctx)

if err := r.RunScripts(ctx, conn, scriptPostDataSection, scriptExecuteBefore); err != nil {
if err = r.RunScripts(ctx, conn, scriptPostDataSection, scriptExecuteBefore); err != nil {
return err
}

Expand All @@ -478,11 +446,14 @@ func (r *Restore) postDataRestore(ctx context.Context) error {
options.DirPath = r.postDataClenUpToc
}

if err := r.pgRestore.Run(ctx, &options); err != nil {
return fmt.Errorf("cannot restore post-data section using pg_restore: %w", err)
if err = r.pgRestore.Run(ctx, &options); err != nil {
var exitErr *exec.ExitError
if r.restoreOpt.ExitOnError || (errors.As(err, &exitErr) && exitErr.ExitCode() != 1) {
return fmt.Errorf("cannot restore post-data section using pg_restore: %w", err)
}
}

if err := r.RunScripts(ctx, conn, scriptPostDataSection, scriptExecuteAfter); err != nil {
if err = r.RunScripts(ctx, conn, scriptPostDataSection, scriptExecuteAfter); err != nil {
return err
}

Expand Down Expand Up @@ -522,6 +493,54 @@ func (r *Restore) Run(ctx context.Context) error {
return nil
}

// taskPusher builds the producer routine for the data section. The returned
// function walks the TOC entries, creates a restoration task for every data
// entry that has to be restored (table data, sequence set or blobs) and sends
// it to tasks. The channel is closed when the routine returns. If ctx is done
// before all tasks are pushed, the routine stops and returns ctx.Err().
func (r *Restore) taskPusher(ctx context.Context, tasks chan restorers.RestoreTask) func() error {
	return func() error {
		defer close(tasks)

		for _, e := range r.tocObj.Entries {
			// Non-blocking cancellation check before handling the next entry.
			select {
			case <-ctx.Done():
				return ctx.Err()
			default:
			}

			// Only data-section entries that passed the restore filter produce tasks.
			if e.Section != toc.SectionData || !r.isNeedRestore(e) {
				continue
			}

			var job restorers.RestoreTask
			switch *e.Desc {
			case toc.TableDataDesc:
				// ON CONFLICT DO NOTHING is only expressible with INSERT, so it
				// selects the insert-format restorer as well.
				insertMode := r.restoreOpt.Inserts || r.restoreOpt.OnConflictDoNothing
				if !insertMode {
					job = restorers.NewTableRestorer(e, r.st, r.restoreOpt.ExitOnError)
				} else {
					job = restorers.NewTableRestorerInsertFormat(
						e, r.st, r.restoreOpt.ExitOnError, r.restoreOpt.OnConflictDoNothing,
						r.cfg.ErrorExclusions,
					)
				}
			case toc.SequenceSetDesc:
				job = restorers.NewSequenceRestorer(e)
			case toc.BlobsDesc:
				job = restorers.NewBlobsRestorer(e, r.st)
			}

			// Entries with any other descriptor produce no task.
			if job == nil {
				continue
			}

			// Hand the task over, giving up if the context is cancelled while waiting.
			select {
			case <-ctx.Done():
				return ctx.Err()
			case tasks <- job:
			}
		}
		return nil
	}
}

func (r *Restore) restoreWorker(ctx context.Context, tasks <-chan restorers.RestoreTask, id int) error {
// TODO: You should execute TX for each COPY stmt
conn, err := pgx.Connect(ctx, r.dsn)
Expand Down Expand Up @@ -554,23 +573,9 @@ func (r *Restore) restoreWorker(ctx context.Context, tasks <-chan restorers.Rest
Msg("restoring")

// Open new transaction for each task
tx, err := conn.Begin(ctx)
if err != nil {
return fmt.Errorf("cannot start transaction (worker %d restoring %s): %w", id, task.DebugInfo(), err)
}
if err = task.Execute(ctx, tx); err != nil {
if txErr := tx.Rollback(ctx); txErr != nil {
log.Warn().
Err(txErr).
Int("workerId", id).
Str("objectName", task.DebugInfo()).
Msg("cannot rollback transaction")
}
if err = task.Execute(ctx, conn); err != nil {
return fmt.Errorf("unable to perform restoration task (worker %d restoring %s): %w", id, task.DebugInfo(), err)
}
if err = tx.Commit(ctx); err != nil {
return fmt.Errorf("cannot commit transaction (worker %d restoring %s): %w", id, task.DebugInfo(), err)
}
log.Debug().
Int("workerId", id).
Str("objectName", task.DebugInfo()).
Expand Down
34 changes: 34 additions & 0 deletions internal/db/postgres/pgcopy/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type columnPos struct {

const (
	// defaultBufferPoolSize is the initial size of each per-column
	// decode/encode buffer.
	defaultBufferPoolSize = 128
	// defaultDecodedBuf is the initial capacity of the encoded row buffer.
	defaultDecodedBuf = 1024
	// UseDynamicSize, when passed to NewRow as the tuple size, makes the row
	// determine the number of columns at runtime while decoding.
	UseDynamicSize = -1
)

// Row - the row driver that works with vanilla COPY format
type Row struct {
Expand All @@ -46,9 +47,21 @@ type Row struct {
newValues []*toolkit.RawValue
// columnPos - list of the column pos within the raw data
columnPos []*columnPos
// tupleSize - number of columns (tuple attributes) currently tracked by the row
tupleSize int
// isDynamic - flag that indicates that row size will be determined in runtime
isDynamic bool
}

func NewRow(tupleSize int) *Row {
var isDynamic bool
if tupleSize == 0 {
panic("tuple size should be greater than 0")
}
if tupleSize == UseDynamicSize {
tupleSize = 0
isDynamic = true
}
pos := make([]*columnPos, tupleSize)
decodeBufferPool := make([][]byte, tupleSize)
encodeBufferPool := make([][]byte, tupleSize)
Expand All @@ -65,6 +78,8 @@ func NewRow(tupleSize int) *Row {
decodeBufferPool: decodeBufferPool,
encodeBufferPool: encodeBufferPool,
encoded: make([]byte, 0, defaultDecodedBuf),
tupleSize: tupleSize,
isDynamic: isDynamic,
}
}

Expand All @@ -81,6 +96,9 @@ func (r *Row) Decode(raw []byte) error {
} else {
colEndPos = colStartPos + colEndPos
}
if r.isDynamic && idx >= r.tupleSize {
r.appendNewEmptyBuffer()
}

p := r.columnPos[idx]
p.start = colStartPos
Expand Down Expand Up @@ -109,6 +127,14 @@ func (r *Row) GetColumn(idx int) (*toolkit.RawValue, error) {
return res, nil
}

// GetColumnRaw returns the raw bytes of the column at position idx, sliced
// from the row's raw buffer using the recorded column boundaries. The returned
// slice shares memory with the raw buffer; it is not a copy.
//
// It returns ErrIndexOutOfRage when idx is negative or not less than the
// number of known columns.
func (r *Row) GetColumnRaw(idx int) ([]byte, error) {
	// Reject negative indexes too: they would otherwise panic on the slice
	// access below instead of reporting an error.
	if idx < 0 || idx >= len(r.columnPos) {
		return nil, ErrIndexOutOfRage
	}
	pos := r.columnPos[idx]
	return r.raw[pos.start:pos.end], nil
}

// SetColumn - set column (replace original) value and decode it later
func (r *Row) SetColumn(idx int, v *toolkit.RawValue) error {
if idx > len(r.columnPos)-1 {
Expand Down Expand Up @@ -144,6 +170,14 @@ func (r *Row) Encode() ([]byte, error) {
return res, nil
}

// appendNewEmptyBuffer extends the row by one column slot: it adds an empty
// column position, a decode buffer and an encode buffer of
// defaultBufferPoolSize bytes each, and a nil new-value placeholder, then
// increments tupleSize.
func (r *Row) appendNewEmptyBuffer() {
	decodeBuf := make([]byte, defaultBufferPoolSize)
	encodeBuf := make([]byte, defaultBufferPoolSize)

	r.tupleSize++
	r.newValues = append(r.newValues, nil)
	r.columnPos = append(r.columnPos, new(columnPos))
	r.decodeBufferPool = append(r.decodeBufferPool, decodeBuf)
	r.encodeBufferPool = append(r.encodeBufferPool, encodeBuf)
}

// Length reports how many column positions the row currently tracks.
func (r *Row) Length() int {
	n := len(r.columnPos)
	return n
}
Expand Down
Loading

0 comments on commit 03d7d7a

Please sign in to comment.