diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go
index 47c524bb27..7f213be924 100644
--- a/flow/connectors/postgres/cdc.go
+++ b/flow/connectors/postgres/cdc.go
@@ -2,6 +2,7 @@ package connpostgres
 
 import (
 	"context"
+	"encoding/binary"
 	"fmt"
 	"log/slog"
 	"sync"
@@ -15,6 +16,7 @@ import (
 	"github.com/jackc/pgx/v5/pgxpool"
 	"github.com/lib/pq/oid"
 	"go.temporal.io/sdk/activity"
+	"go.temporal.io/sdk/log"
 
 	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	geo "github.com/PeerDB-io/peer-flow/datatypes"
@@ -25,23 +27,19 @@ import (
 	"github.com/PeerDB-io/peer-flow/shared"
 )
 
+type TxBuffer struct {
+	Streams      [][]byte
+	Lsn          pglogrepl.LSN
+	FirstSegment bool
+}
+
 type PostgresCDCSource struct {
 	*PostgresConnector
-	srcTableIDNameMapping  map[uint32]string
-	tableNameMapping       map[string]model.NameAndExclude
-	tableNameSchemaMapping map[string]*protos.TableSchema
-	relationMessageMapping model.RelationMessageMapping
-	slot                   string
-	publication            string
-	typeMap                *pgtype.Map
-	commitLock             *pglogrepl.BeginMessage
-
-	// for partitioned tables, maps child relid to parent relid
-	childToParentRelIDMapping map[uint32]uint32
-
-	// for storing chema delta audit logs to catalog
-	catalogPool *pgxpool.Pool
-	flowJobName string
+	*PostgresCDCConfig
+	typeMap    *pgtype.Map
+	commitLock *pglogrepl.BeginMessage
+	txBuffer   map[uint32]*TxBuffer
+	inStream   bool
 }
 
 type PostgresCDCConfig struct {
@@ -49,28 +47,29 @@ type PostgresCDCConfig struct {
 	SrcTableIDNameMapping  map[uint32]string
 	TableNameMapping       map[string]model.NameAndExclude
 	TableNameSchemaMapping map[string]*protos.TableSchema
-	ChildToParentRelIDMap  map[uint32]uint32
+	// for partitioned tables, maps child relid to parent relid
+	ChildToParentRelIDMap map[uint32]uint32
+	// schema delta audit logs are written to the catalog via CatalogPool
 	RelationMessageMapping model.RelationMessageMapping
 	FlowJobName            string
 	Slot                   string
 	Publication            string
+	Version                int32
 }
 
 // Create a new PostgresCDCSource
 func (c *PostgresConnector) NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) *PostgresCDCSource {
+	var txBuffer map[uint32]*TxBuffer
+	if cdcConfig.Version >= 2 {
+		txBuffer = make(map[uint32]*TxBuffer)
+	}
 	return &PostgresCDCSource{
-		PostgresConnector:         c,
-		srcTableIDNameMapping:     cdcConfig.SrcTableIDNameMapping,
-		tableNameMapping:          cdcConfig.TableNameMapping,
-		tableNameSchemaMapping:    cdcConfig.TableNameSchemaMapping,
-		relationMessageMapping:    cdcConfig.RelationMessageMapping,
-		slot:                      cdcConfig.Slot,
-		publication:               cdcConfig.Publication,
-		childToParentRelIDMapping: cdcConfig.ChildToParentRelIDMap,
-		typeMap:                   pgtype.NewMap(),
-		commitLock:                nil,
-		catalogPool:               cdcConfig.CatalogPool,
-		flowJobName:               cdcConfig.FlowJobName,
+		PostgresConnector: c,
+		PostgresCDCConfig: cdcConfig,
+		typeMap:           pgtype.NewMap(),
+		commitLock:        nil,
+		inStream:          false,
+		txBuffer:          txBuffer,
 	}
 }
 
@@ -283,6 +282,35 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma
 	return qvalue.QValueString{Val: string(data)}, nil
 }
 
+type cdcRecordProcessor[Items model.Items] struct {
+	recordStore                *utils.CdcStore[Items]
+	records                    *model.CDCStream[Items]
+	pullRequest                *model.PullRecordsRequest[Items]
+	processor                  replProcessor[Items]
+	nextStandbyMessageDeadline time.Time
+}
+
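+// addRecordWithKey stores rec under key so that later records for the same
+// primary key can be merged, and appends it to the outgoing stream. When the
+// store transitions from empty to non-empty it signals consumers and pushes
+// the standby deadline so the freshly started batch is not cut short.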
+func (rp *cdcRecordProcessor[Items]) addRecordWithKey(
+	ctx context.Context,
+	logger log.Logger,
+	key model.TableWithPkey,
+	rec model.Record[Items],
+) error {
+	if err := rp.recordStore.Set(logger, key, rec); err != nil {
+		return err
+	}
+	if err := rp.records.AddRecord(ctx, rec); err != nil {
+		return err
+	}
+
+	if rp.recordStore.Len() == 1 {
+		rp.records.SignalAsNotEmpty()
+		rp.nextStandbyMessageDeadline = time.Now().Add(rp.pullRequest.IdleTimeout)
+		logger.Info(fmt.Sprintf("pushing the standby deadline to %s", rp.nextStandbyMessageDeadline))
+	}
+	return nil
+}
+
 // PullCdcRecords pulls records from req's cdc stream
 func PullCdcRecords[Items model.Items](
 	ctx context.Context,
@@ -318,43 +346,31 @@ func PullCdcRecords[Items model.Items](
 	}
 
 	var standByLastLogged time.Time
-	cdcRecordsStorage, err := utils.NewCDCStore[Items](ctx, p.flowJobName)
+	cdcRecordStore, err := utils.NewCDCStore[Items](ctx, p.FlowJobName)
 	if err != nil {
 		return err
 	}
 	defer func() {
-		if cdcRecordsStorage.IsEmpty() {
+		if cdcRecordStore.IsEmpty() {
 			records.SignalAsEmpty()
 		}
-		logger.Info(fmt.Sprintf("[finished] PullRecords streamed %d records", cdcRecordsStorage.Len()))
-		err := cdcRecordsStorage.Close()
-		if err != nil {
+		logger.Info(fmt.Sprintf("[finished] PullRecords streamed %d records", cdcRecordStore.Len()))
+		if err := cdcRecordStore.Close(); err != nil {
 			logger.Warn("failed to clean up records storage", slog.Any("error", err))
 		}
 	}()
 
 	shutdown := shared.Interval(ctx, time.Minute, func() {
-		logger.Info(fmt.Sprintf("pulling records, currently have %d records", cdcRecordsStorage.Len()))
+		logger.Info(fmt.Sprintf("pulling records, currently have %d records", cdcRecordStore.Len()))
 	})
 	defer shutdown()
 
-	standbyMessageTimeout := req.IdleTimeout
-	nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout)
-
-	addRecordWithKey := func(key model.TableWithPkey, rec model.Record[Items]) error {
-		if err := cdcRecordsStorage.Set(logger, key, rec); err != nil {
-			return err
-		}
-		if err := records.AddRecord(ctx, rec); err != nil {
-			return err
-		}
-
-		if cdcRecordsStorage.Len() == 1 {
-			records.SignalAsNotEmpty()
-			nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout)
-			logger.Info(fmt.Sprintf("pushing the standby deadline to %s", nextStandbyMessageDeadline))
-		}
-		return nil
+	recordProcessor := cdcRecordProcessor[Items]{
+		recordStore:                cdcRecordStore,
+		records:                    records,
+		nextStandbyMessageDeadline: time.Now().Add(req.IdleTimeout),
+		pullRequest:                req,
+		processor:                  processor,
 	}
 
 	pkmRequiresResponse := false
@@ -362,64 +378,63 @@ func PullCdcRecords[Items model.Items](
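+	// receive loop: exits via break once MaxBatchSize records have accumulated
+	// or once a commit lands after the standby deadline; transactions that are
+	// still streaming are then flushed to the catalog before returning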
 	for {
 		if pkmRequiresResponse {
-			err := sendStandbyAfterReplLock("pkm-response")
-			if err != nil {
+			if err := sendStandbyAfterReplLock("pkm-response"); err != nil {
 				return err
 			}
 			pkmRequiresResponse = false
 
 			if time.Since(standByLastLogged) > 10*time.Second {
-				numRowsProcessedMessage := fmt.Sprintf("processed %d rows", cdcRecordsStorage.Len())
+				numRowsProcessedMessage := fmt.Sprintf("processed %d rows", cdcRecordStore.Len())
 				logger.Info("Sent Standby status message. " + numRowsProcessedMessage)
 				standByLastLogged = time.Now()
 			}
 		}
 
 		if p.commitLock == nil {
-			cdclen := cdcRecordsStorage.Len()
+			cdclen := cdcRecordStore.Len()
 			if cdclen >= 0 && uint32(cdclen) >= req.MaxBatchSize {
-				return nil
+				break
 			}
 
 			if waitingForCommit {
 				logger.Info(fmt.Sprintf(
 					"[%s] commit received, returning currently accumulated records - %d",
-					p.flowJobName,
-					cdcRecordsStorage.Len()),
+					p.FlowJobName,
+					cdcRecordStore.Len()),
 				)
-				return nil
+				break
 			}
 		}
 
 		// if we are past the next standby deadline (?)
-		if time.Now().After(nextStandbyMessageDeadline) {
-			if !cdcRecordsStorage.IsEmpty() {
-				logger.Info(fmt.Sprintf("standby deadline reached, have %d records", cdcRecordsStorage.Len()))
+		if time.Now().After(recordProcessor.nextStandbyMessageDeadline) {
+			if !cdcRecordStore.IsEmpty() {
+				logger.Info(fmt.Sprintf("standby deadline reached, have %d records", cdcRecordStore.Len()))
 
 				if p.commitLock == nil {
 					logger.Info(
 						fmt.Sprintf("no commit lock, returning currently accumulated records - %d",
-							cdcRecordsStorage.Len()))
+							cdcRecordStore.Len()))
 					return nil
 				} else {
 					logger.Info(fmt.Sprintf("commit lock, waiting for commit to return records - %d",
-						cdcRecordsStorage.Len()))
+						cdcRecordStore.Len()))
 					waitingForCommit = true
 				}
 			} else {
 				logger.Info(fmt.Sprintf("[%s] standby deadline reached, no records accumulated, continuing to wait",
-					p.flowJobName),
+					p.FlowJobName),
 				)
 			}
-			nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout)
+			recordProcessor.nextStandbyMessageDeadline = time.Now().Add(req.IdleTimeout)
 		}
 
 		var receiveCtx context.Context
 		var cancel context.CancelFunc
-		if cdcRecordsStorage.IsEmpty() {
+		if cdcRecordStore.IsEmpty() {
 			receiveCtx, cancel = context.WithCancel(ctx)
 		} else {
-			receiveCtx, cancel = context.WithDeadline(ctx, nextStandbyMessageDeadline)
+			receiveCtx, cancel = context.WithDeadline(ctx, recordProcessor.nextStandbyMessageDeadline)
 		}
 		rawMsg, err := func() (pgproto3.BackendMessage, error) {
 			replLock.Lock()
@@ -436,8 +451,8 @@
 		if err != nil && p.commitLock == nil {
 			if pgconn.Timeout(err) {
 				logger.Info(fmt.Sprintf("Stand-by deadline reached, returning currently accumulated records - %d",
-					cdcRecordsStorage.Len()))
-				return nil
+					cdcRecordStore.Len()))
+				break
 			} else {
 				return fmt.Errorf("ReceiveMessage failed: %w", err)
 			}
@@ -480,126 +495,30 @@
 			logger.Debug(fmt.Sprintf("XLogData => WALStart %s ServerWALEnd %s ServerTime %s\n",
 				xld.WALStart, xld.ServerWALEnd, xld.ServerTime))
 
-			rec, err := processMessage(ctx, p, records, xld, clientXLogPos, processor)
-			if err != nil {
+			if err := recordProcessor.processXLogData(ctx, p, xld, msg.Data[1:], clientXLogPos); err != nil {
 				return fmt.Errorf("error processing message: %w", err)
 			}
-			if rec != nil {
-				tableName := rec.GetDestinationTableName()
-				switch r := rec.(type) {
-				case *model.UpdateRecord[Items]:
-					// tableName here is destination tableName.
-					// should be ideally sourceTableName as we are in PullRecords.
-					// will change in future
-					isFullReplica := req.TableNameSchemaMapping[tableName].IsReplicaIdentityFull
-					if isFullReplica {
-						err := addRecordWithKey(model.TableWithPkey{}, rec)
-						if err != nil {
-							return err
-						}
-					} else {
-						tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec)
-						if err != nil {
-							return err
-						}
-
-						latestRecord, ok, err := cdcRecordsStorage.Get(tablePkeyVal)
-						if err != nil {
-							return err
-						}
-						if !ok {
-							err = addRecordWithKey(tablePkeyVal, rec)
-						} else {
-							// iterate through unchanged toast cols and set them in new record
-							updatedCols := r.NewItems.UpdateIfNotExists(latestRecord.GetItems())
-							for _, col := range updatedCols {
-								delete(r.UnchangedToastColumns, col)
-							}
-							err = addRecordWithKey(tablePkeyVal, rec)
-						}
-						if err != nil {
-							return err
-						}
-					}
-
-				case *model.InsertRecord[Items]:
-					isFullReplica := req.TableNameSchemaMapping[tableName].IsReplicaIdentityFull
-					if isFullReplica {
-						err := addRecordWithKey(model.TableWithPkey{}, rec)
-						if err != nil {
-							return err
-						}
-					} else {
-						tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec)
-						if err != nil {
-							return err
-						}
-
-						err = addRecordWithKey(tablePkeyVal, rec)
-						if err != nil {
-							return err
-						}
-					}
-				case *model.DeleteRecord[Items]:
-					isFullReplica := req.TableNameSchemaMapping[tableName].IsReplicaIdentityFull
-					if isFullReplica {
-						err := addRecordWithKey(model.TableWithPkey{}, rec)
-						if err != nil {
-							return err
-						}
-					} else {
-						tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec)
-						if err != nil {
-							return err
-						}
-
-						latestRecord, ok, err := cdcRecordsStorage.Get(tablePkeyVal)
-						if err != nil {
-							return err
-						}
-						if ok {
-							r.Items = latestRecord.GetItems()
-							if updateRecord, ok := latestRecord.(*model.UpdateRecord[Items]); ok {
-								r.UnchangedToastColumns = updateRecord.UnchangedToastColumns
-							}
-						} else {
-							// there is nothing to backfill the items in the delete record with,
-							// so don't update the row with this record
-							// add sentinel value to prevent update statements from selecting
-							r.UnchangedToastColumns = map[string]struct{}{
-								"_peerdb_not_backfilled_delete": {},
-							}
-						}
-
-						// A delete can only be followed by an INSERT, which does not need backfilling
-						// No need to store DeleteRecords in memory or disk.
-						err = addRecordWithKey(model.TableWithPkey{}, rec)
-						if err != nil {
-							return err
-						}
-					}
-
-				case *model.RelationRecord[Items]:
-					tableSchemaDelta := r.TableSchemaDelta
-					if len(tableSchemaDelta.AddedColumns) > 0 {
-						logger.Info(fmt.Sprintf("Detected schema change for table %s, addedColumns: %v",
-							tableSchemaDelta.SrcTableName, tableSchemaDelta.AddedColumns))
-						records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta)
-					}
-
-				case *model.MessageRecord[Items]:
-					if err := addRecordWithKey(model.TableWithPkey{}, rec); err != nil {
-						return err
-					}
-				}
-			}
-
 			if xld.WALStart > clientXLogPos {
 				clientXLogPos = xld.WALStart
 			}
 		}
 	}
+
+	// persist transactions that are still streaming so a later pull can
+	// resume them from the catalog
+	for xid, txbuf := range p.txBuffer {
+		if _, err := p.CatalogPool.Exec(
+			ctx,
+			"insert into v2cdc (flow_name, xid, lsn, stream) values ($1, $2, $3, $4) on conflict do nothing",
+			p.FlowJobName,
+			xid,
+			txbuf.Lsn,
+			txbuf.Streams,
+		); err != nil {
+			return err
+		}
+	}
+
+	return nil
 }
 
 func (p *PostgresCDCSource) baseRecord(lsn pglogrepl.LSN) model.BaseRecord {
@@ -613,82 +532,159 @@ func (p *PostgresCDCSource) baseRecord(lsn pglogrepl.LSN) model.BaseRecord {
 	}
 }
 
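+// processXLogData parses one XLogData payload and hands it to processMessage.
+// With proto_version '2', DML and relation messages that arrive inside a
+// stream block carry the transaction's xid as a 4-byte big-endian value right
+// after the one-byte message type; such segments are buffered raw, per xid,
+// and only parsed once their StreamCommit arrives.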
-func processMessage[Items model.Items](
+func (rp *cdcRecordProcessor[Items]) processXLogData(
 	ctx context.Context,
 	p *PostgresCDCSource,
-	batch *model.CDCStream[Items],
 	xld pglogrepl.XLogData,
+	xldbytes []byte,
 	currentClientXlogPos pglogrepl.LSN,
-	processor replProcessor[Items],
-) (model.Record[Items], error) {
-	logger := logger.LoggerFromCtx(ctx)
-	logicalMsg, err := pglogrepl.Parse(xld.WALData)
+) error {
+	var logicalMsg pglogrepl.Message
+	var err error
+	if p.Version < 2 {
+		logicalMsg, err = pglogrepl.Parse(xld.WALData)
+	} else {
+		if p.inStream &&
+			(xld.WALData[0] == byte(pglogrepl.MessageTypeUpdate) ||
+				xld.WALData[0] == byte(pglogrepl.MessageTypeInsert) ||
+				xld.WALData[0] == byte(pglogrepl.MessageTypeDelete) ||
+				xld.WALData[0] == byte(pglogrepl.MessageTypeRelation)) {
+			xid := binary.BigEndian.Uint32(xld.WALData[1:])
+			txbuf := p.txBuffer[xid]
+			txbuf.Streams = append(txbuf.Streams, xldbytes)
+			// buffered raw; parsed when the transaction's StreamCommit arrives
+			return nil
+		}
+		logicalMsg, err = pglogrepl.ParseV2(xld.WALData, p.inStream)
+	}
 	if err != nil {
-		return nil, fmt.Errorf("error parsing logical message: %w", err)
+		return fmt.Errorf("error parsing logical message: %w", err)
 	}
+	return rp.processMessage(ctx, p, xld.WALStart, logicalMsg, currentClientXlogPos)
+}
+
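+// processMessage dispatches one parsed logical replication message. The V2
+// message variants wrap the same payloads as their V1 counterparts, while the
+// stream control messages (start/stop/commit/abort) maintain txBuffer and
+// replay or discard the buffered segments.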
+func (rp *cdcRecordProcessor[Items]) processMessage(
+	ctx context.Context,
+	p *PostgresCDCSource,
+	lsn pglogrepl.LSN,
+	logicalMsg pglogrepl.Message,
+	currentClientXlogPos pglogrepl.LSN,
+) error {
+	logger := logger.LoggerFromCtx(ctx)
 	switch msg := logicalMsg.(type) {
 	case *pglogrepl.BeginMessage:
 		logger.Debug("BeginMessage", slog.Any("FinalLSN", msg.FinalLSN), slog.Any("XID", msg.Xid))
 		p.commitLock = msg
 	case *pglogrepl.InsertMessage:
-		return processInsertMessage(p, xld.WALStart, msg, processor)
+		return rp.processInsertMessage(ctx, p, lsn, msg)
+	case *pglogrepl.InsertMessageV2:
+		return rp.processInsertMessage(ctx, p, lsn, &msg.InsertMessage)
 	case *pglogrepl.UpdateMessage:
-		return processUpdateMessage(p, xld.WALStart, msg, processor)
+		return rp.processUpdateMessage(ctx, p, lsn, msg)
+	case *pglogrepl.UpdateMessageV2:
+		return rp.processUpdateMessage(ctx, p, lsn, &msg.UpdateMessage)
 	case *pglogrepl.DeleteMessage:
-		return processDeleteMessage(p, xld.WALStart, msg, processor)
+		return rp.processDeleteMessage(ctx, p, lsn, msg)
+	case *pglogrepl.DeleteMessageV2:
+		return rp.processDeleteMessage(ctx, p, lsn, &msg.DeleteMessage)
+	case *pglogrepl.RelationMessage:
+		return rp.processRelationMessage(ctx, p, currentClientXlogPos, msg)
+	case *pglogrepl.RelationMessageV2:
+		return rp.processRelationMessage(ctx, p, currentClientXlogPos, &msg.RelationMessage)
+	case *pglogrepl.LogicalDecodingMessage:
+		return rp.processLogicalDecodingMessage(ctx, p, lsn, msg)
+	case *pglogrepl.LogicalDecodingMessageV2:
+		return rp.processLogicalDecodingMessage(ctx, p, lsn, &msg.LogicalDecodingMessage)
 	case *pglogrepl.CommitMessage:
 		// for a commit message, update the last checkpoint id for the record batch.
 		logger.Debug("CommitMessage",
 			slog.Any("CommitLSN", msg.CommitLSN),
 			slog.Any("TransactionEndLSN", msg.TransactionEndLSN))
-		batch.UpdateLatestCheckpoint(int64(msg.CommitLSN))
+		rp.records.UpdateLatestCheckpoint(int64(msg.CommitLSN))
 		p.commitLock = nil
-	case *pglogrepl.RelationMessage:
-		// treat all relation messages as corresponding to parent if partitioned.
-		msg.RelationID = p.getParentRelIDIfPartitioned(msg.RelationID)
-
-		if _, exists := p.srcTableIDNameMapping[msg.RelationID]; !exists {
-			return nil, nil
-		}
-
-		logger.Debug("RelationMessage",
-			slog.Any("RelationID", msg.RelationID),
-			slog.String("Namespace", msg.Namespace),
-			slog.String("RelationName", msg.RelationName),
-			slog.Any("Columns", msg.Columns))
-
-		return processRelationMessage[Items](ctx, p, currentClientXlogPos, msg)
-	case *pglogrepl.LogicalDecodingMessage:
-		logger.Info("LogicalDecodingMessage",
-			slog.Bool("Transactional", msg.Transactional),
-			slog.String("Prefix", msg.Prefix),
-			slog.Int64("LSN", int64(msg.LSN)))
-		if !msg.Transactional {
-			batch.UpdateLatestCheckpoint(int64(msg.LSN))
-		}
-
-		return &model.MessageRecord[Items]{
-			BaseRecord: p.baseRecord(msg.LSN),
-			Prefix:     msg.Prefix,
-			Content:    string(msg.Content),
-		}, nil
+	case *pglogrepl.StreamCommitMessageV2:
+		txbuf := p.txBuffer[msg.Xid]
+		if !txbuf.FirstSegment {
+			rows, err := p.CatalogPool.Query(ctx,
+				"select stream from v2cdc where flow_name = $1 and xid = $2 order by lsn",
+				p.FlowJobName, msg.Xid)
+			if err != nil {
+				return err
+			}
+			for rows.Next() {
+				var stream [][]byte
+				if err := rows.Scan(&stream); err != nil {
+					return err
+				}
+
+				for _, m := range stream {
+					mxld, err := pglogrepl.ParseXLogData(m)
+					if err != nil {
+						return err
+					}
+					// segments were captured inside a stream block, so parse
+					// with inStream set to true; p.inStream is already false
+					// again by commit time
+					logicalMsg, err := pglogrepl.ParseV2(mxld.WALData, true)
+					if err != nil {
+						return err
+					}
+					if err := rp.processMessage(ctx, p, mxld.WALStart, logicalMsg, currentClientXlogPos); err != nil {
+						return err
+					}
+				}
+			}
+			if err := rows.Err(); err != nil {
+				return err
+			}
+		}
+		for _, m := range txbuf.Streams {
+			mxld, err := pglogrepl.ParseXLogData(m)
+			if err != nil {
+				return err
+			}
+			logicalMsg, err := pglogrepl.ParseV2(mxld.WALData, true)
+			if err != nil {
+				return err
+			}
+			if err := rp.processMessage(ctx, p, mxld.WALStart, logicalMsg, currentClientXlogPos); err != nil {
+				return err
+			}
+		}
+		rp.records.UpdateLatestCheckpoint(int64(msg.CommitLSN))
+		delete(p.txBuffer, msg.Xid)
+	case *pglogrepl.StreamAbortMessageV2:
+		if txbuf, ok := p.txBuffer[msg.Xid]; ok && !txbuf.FirstSegment {
+			if _, err := p.CatalogPool.Exec(ctx,
+				"delete from v2cdc where flow_name = $1 and xid = $2",
+				p.FlowJobName, msg.Xid,
+			); err != nil {
+				return err
+			}
+		}
+		delete(p.txBuffer, msg.Xid)
+	case *pglogrepl.StreamStartMessageV2:
+		if _, ok := p.txBuffer[msg.Xid]; !ok {
+			p.txBuffer[msg.Xid] = &TxBuffer{Lsn: lsn, FirstSegment: msg.FirstSegment != 0}
+		}
+		p.inStream = true
+	case *pglogrepl.StreamStopMessageV2:
+		p.inStream = false
 	default:
 		logger.Warn(fmt.Sprintf("%T not supported", msg))
 	}
-	return nil, nil
+	return nil
 }
 
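+// processInsertMessage converts an InsertMessage into an InsertRecord and adds
+// it to the batch, keyed by primary key unless the table uses replica identity
+// FULL (in which case records are never merged by key).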
-func processInsertMessage[Items model.Items](
+func (rp *cdcRecordProcessor[Items]) processInsertMessage(
+	ctx context.Context,
 	p *PostgresCDCSource,
 	lsn pglogrepl.LSN,
 	msg *pglogrepl.InsertMessage,
-	processor replProcessor[Items],
-) (model.Record[Items], error) {
+) error {
 	relID := p.getParentRelIDIfPartitioned(msg.RelationID)
 
-	tableName, exists := p.srcTableIDNameMapping[relID]
+	tableName, exists := p.SrcTableIDNameMapping[relID]
 	if !exists {
-		return nil, nil
+		return nil
 	}
 
 	// log lsn and relation id for debugging
@@ -697,34 +693,51 @@
 	rel, ok := p.relationMessageMapping[relID]
 	if !ok {
-		return nil, fmt.Errorf("unknown relation id: %d", relID)
+		return fmt.Errorf("unknown relation id: %d", relID)
 	}
 
-	items, _, err := processTuple(processor, p, msg.Tuple, rel, p.tableNameMapping[tableName].Exclude)
+	items, _, err := processTuple(rp.processor, p, msg.Tuple, rel, p.TableNameMapping[tableName].Exclude)
 	if err != nil {
-		return nil, fmt.Errorf("error converting tuple to map: %w", err)
+		return fmt.Errorf("error converting tuple to map: %w", err)
 	}
 
-	return &model.InsertRecord[Items]{
+	rec := &model.InsertRecord[Items]{
 		BaseRecord:           p.baseRecord(lsn),
 		Items:                items,
-		DestinationTableName: p.tableNameMapping[tableName].Name,
+		DestinationTableName: p.TableNameMapping[tableName].Name,
 		SourceTableName:      tableName,
-	}, nil
+	}
+	isFullReplica := rp.pullRequest.TableNameSchemaMapping[tableName].IsReplicaIdentityFull
+	if isFullReplica {
+		if err := rp.addRecordWithKey(ctx, p.logger, model.TableWithPkey{}, rec); err != nil {
+			return err
+		}
+	} else {
+		tablePkeyVal, err := model.RecToTablePKey(rp.pullRequest.TableNameSchemaMapping, rec)
+		if err != nil {
+			return err
+		}
+
+		if err := rp.addRecordWithKey(ctx, p.logger, tablePkeyVal, rec); err != nil {
+			return err
+		}
+	}
+	return nil
 }
 
-// processUpdateMessage processes an update message and returns an UpdateRecord
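+// processUpdateMessage converts an UpdateMessage into an UpdateRecord, reusing
+// any earlier buffered record for the same key to prune unchanged toast
+// columns from the update.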
-func processUpdateMessage[Items model.Items](
+func (rp *cdcRecordProcessor[Items]) processUpdateMessage(
+	ctx context.Context,
 	p *PostgresCDCSource,
 	lsn pglogrepl.LSN,
 	msg *pglogrepl.UpdateMessage,
-	processor replProcessor[Items],
-) (model.Record[Items], error) {
+) error {
 	relID := p.getParentRelIDIfPartitioned(msg.RelationID)
 
-	tableName, exists := p.srcTableIDNameMapping[relID]
+	tableName, exists := p.SrcTableIDNameMapping[relID]
 	if !exists {
-		return nil, nil
+		return nil
 	}
 
 	// log lsn and relation id for debugging
@@ -733,42 +746,77 @@
 	rel, ok := p.relationMessageMapping[relID]
 	if !ok {
-		return nil, fmt.Errorf("unknown relation id: %d", relID)
+		return fmt.Errorf("unknown relation id: %d", relID)
 	}
 
-	oldItems, _, err := processTuple(processor, p, msg.OldTuple, rel, p.tableNameMapping[tableName].Exclude)
+	oldItems, _, err := processTuple(rp.processor, p, msg.OldTuple, rel, p.TableNameMapping[tableName].Exclude)
 	if err != nil {
-		return nil, fmt.Errorf("error converting old tuple to map: %w", err)
+		return fmt.Errorf("error converting old tuple to map: %w", err)
 	}
 
 	newItems, unchangedToastColumns, err := processTuple(
-		processor, p, msg.NewTuple, rel, p.tableNameMapping[tableName].Exclude)
+		rp.processor, p, msg.NewTuple, rel, p.TableNameMapping[tableName].Exclude)
 	if err != nil {
-		return nil, fmt.Errorf("error converting new tuple to map: %w", err)
+		return fmt.Errorf("error converting new tuple to map: %w", err)
 	}
 
-	return &model.UpdateRecord[Items]{
+	rec := &model.UpdateRecord[Items]{
 		BaseRecord:            p.baseRecord(lsn),
 		OldItems:              oldItems,
 		NewItems:              newItems,
-		DestinationTableName:  p.tableNameMapping[tableName].Name,
+		DestinationTableName:  p.TableNameMapping[tableName].Name,
 		SourceTableName:       tableName,
 		UnchangedToastColumns: unchangedToastColumns,
-	}, nil
+	}
+
+	// tableName here is the destination table name; it should ideally be the
+	// source table name, as we are in PullRecords. This will change in future.
+	isFullReplica := rp.pullRequest.TableNameSchemaMapping[tableName].IsReplicaIdentityFull
+	if isFullReplica {
+		if err := rp.addRecordWithKey(ctx, p.logger, model.TableWithPkey{}, rec); err != nil {
+			return err
+		}
+	} else {
+		tablePkeyVal, err := model.RecToTablePKey(rp.pullRequest.TableNameSchemaMapping, rec)
+		if err != nil {
+			return err
+		}
+
+		latestRecord, ok, err := rp.recordStore.Get(tablePkeyVal)
+		if err != nil {
+			return err
+		}
+		if ok {
+			// iterate through unchanged toast cols and set them in the new record
+			updatedCols := rec.NewItems.UpdateIfNotExists(latestRecord.GetItems())
+			for _, col := range updatedCols {
+				delete(rec.UnchangedToastColumns, col)
+			}
+		}
+		if err := rp.addRecordWithKey(ctx, p.logger, tablePkeyVal, rec); err != nil {
+			return err
+		}
+	}
+
+	return nil
 }
 
-// processDeleteMessage processes a delete message and returns a DeleteRecord
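+// processDeleteMessage converts a DeleteMessage into a DeleteRecord,
+// backfilling its items from any buffered record for the same key.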
-func processDeleteMessage[Items model.Items](
+func (rp *cdcRecordProcessor[Items]) processDeleteMessage(
+	ctx context.Context,
 	p *PostgresCDCSource,
 	lsn pglogrepl.LSN,
 	msg *pglogrepl.DeleteMessage,
-	processor replProcessor[Items],
-) (model.Record[Items], error) {
+) error {
 	relID := p.getParentRelIDIfPartitioned(msg.RelationID)
 
-	tableName, exists := p.srcTableIDNameMapping[relID]
+	tableName, exists := p.SrcTableIDNameMapping[relID]
 	if !exists {
-		return nil, nil
+		return nil
 	}
 
 	// log lsn and relation id for debugging
@@ -777,20 +825,56 @@
 	rel, ok := p.relationMessageMapping[relID]
 	if !ok {
-		return nil, fmt.Errorf("unknown relation id: %d", relID)
+		return fmt.Errorf("unknown relation id: %d", relID)
 	}
 
-	items, _, err := processTuple(processor, p, msg.OldTuple, rel, p.tableNameMapping[tableName].Exclude)
+	items, _, err := processTuple(rp.processor, p, msg.OldTuple, rel, p.TableNameMapping[tableName].Exclude)
 	if err != nil {
-		return nil, fmt.Errorf("error converting tuple to map: %w", err)
+		return fmt.Errorf("error converting tuple to map: %w", err)
 	}
 
-	return &model.DeleteRecord[Items]{
+	rec := &model.DeleteRecord[Items]{
 		BaseRecord:           p.baseRecord(lsn),
 		Items:                items,
-		DestinationTableName: p.tableNameMapping[tableName].Name,
+		DestinationTableName: p.TableNameMapping[tableName].Name,
 		SourceTableName:      tableName,
-	}, nil
+	}
+	isFullReplica := rp.pullRequest.TableNameSchemaMapping[tableName].IsReplicaIdentityFull
+	if isFullReplica {
+		if err := rp.addRecordWithKey(ctx, p.logger, model.TableWithPkey{}, rec); err != nil {
+			return err
+		}
+	} else {
+		tablePkeyVal, err := model.RecToTablePKey(rp.pullRequest.TableNameSchemaMapping, rec)
+		if err != nil {
+			return err
+		}
+
+		latestRecord, ok, err := rp.recordStore.Get(tablePkeyVal)
+		if err != nil {
+			return err
+		}
+		if ok {
+			rec.Items = latestRecord.GetItems()
+			if updateRecord, ok := latestRecord.(*model.UpdateRecord[Items]); ok {
+				rec.UnchangedToastColumns = updateRecord.UnchangedToastColumns
+			}
+		} else {
+			// there is nothing to backfill the items in the delete record with,
+			// so don't update the row with this record; add a sentinel value to
+			// prevent update statements from selecting it
+			rec.UnchangedToastColumns = map[string]struct{}{
+				"_peerdb_not_backfilled_delete": {},
+			}
+		}
+
+		// A delete can only be followed by an INSERT, which does not need
+		// backfilling, so there is no need to store DeleteRecords in memory or
+		// on disk.
+		if err := rp.addRecordWithKey(ctx, p.logger, model.TableWithPkey{}, rec); err != nil {
+			return err
+		}
+	}
+	return nil
 }
 
 func auditSchemaDelta[Items model.Items](ctx context.Context, p *PostgresCDCSource, rec *model.RelationRecord[Items]) error {
@@ -798,41 +882,52 @@
 	workflowID := activityInfo.WorkflowExecution.ID
 	runID := activityInfo.WorkflowExecution.RunID
 
-	_, err := p.catalogPool.Exec(ctx,
-		`INSERT INTO
-			peerdb_stats.schema_deltas_audit_log(flow_job_name,workflow_id,run_id,delta_info)
+	_, err := p.CatalogPool.Exec(ctx,
+		`INSERT INTO peerdb_stats.schema_deltas_audit_log(flow_job_name,workflow_id,run_id,delta_info)
 		VALUES($1,$2,$3,$4)`,
-		p.flowJobName, workflowID, runID, rec)
+		p.FlowJobName, workflowID, runID, rec)
 	if err != nil {
 		return fmt.Errorf("failed to insert row into table: %w", err)
 	}
 	return nil
 }
 
-// processRelationMessage processes a RelationMessage and returns a TableSchemaDelta
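+// processRelationMessage diffs an incoming RelationMessage against the cached
+// table schema and, when new columns have appeared, emits a TableSchemaDelta
+// and audits it to the catalog.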
-func processRelationMessage[Items model.Items](
+func (rp *cdcRecordProcessor[Items]) processRelationMessage(
 	ctx context.Context,
 	p *PostgresCDCSource,
 	lsn pglogrepl.LSN,
-	currRel *pglogrepl.RelationMessage,
-) (model.Record[Items], error) {
+	msg *pglogrepl.RelationMessage,
+) error {
+	// treat all relation messages as corresponding to the parent if partitioned.
+	msg.RelationID = p.getParentRelIDIfPartitioned(msg.RelationID)
+
 	// not present in tables to sync, return immediately
-	if _, ok := p.srcTableIDNameMapping[currRel.RelationID]; !ok {
+	if _, ok := p.SrcTableIDNameMapping[msg.RelationID]; !ok {
 		p.logger.Info("relid not present in srcTableIDNameMapping, skipping relation message",
-			slog.Uint64("relId", uint64(currRel.RelationID)))
-		return nil, nil
+			slog.Uint64("relId", uint64(msg.RelationID)))
+		return nil
 	}
 
+	p.logger.Debug("RelationMessage",
+		slog.Any("RelationID", msg.RelationID),
+		slog.String("Namespace", msg.Namespace),
+		slog.String("RelationName", msg.RelationName),
+		slog.Any("Columns", msg.Columns))
+
 	// retrieve current TableSchema for table changed
 	// tableNameSchemaMapping uses dst table name as the key, so annoying lookup
-	prevSchema := p.tableNameSchemaMapping[p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Name]
+	prevSchema := p.TableNameSchemaMapping[p.TableNameMapping[p.SrcTableIDNameMapping[msg.RelationID]].Name]
 	// creating maps for lookup later
 	prevRelMap := make(map[string]string)
 	currRelMap := make(map[string]string)
 	for _, column := range prevSchema.Columns {
 		prevRelMap[column.Name] = column.Type
 	}
-	for _, column := range currRel.Columns {
+	for _, column := range msg.Columns {
 		switch prevSchema.System {
 		case protos.TypeSystem_Q:
 			qKind := p.postgresOIDToQValueKind(column.DataType)
@@ -851,16 +946,16 @@
 	}
 
 	schemaDelta := &protos.TableSchemaDelta{
-		SrcTableName: p.srcTableIDNameMapping[currRel.RelationID],
-		DstTableName: p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Name,
+		SrcTableName: p.SrcTableIDNameMapping[msg.RelationID],
+		DstTableName: p.TableNameMapping[p.SrcTableIDNameMapping[msg.RelationID]].Name,
 		AddedColumns: nil,
 		System:       prevSchema.System,
 	}
-	for _, column := range currRel.Columns {
+	for _, column := range msg.Columns {
 		// not present in previous relation message, but in current one, so added.
 		if _, ok := prevRelMap[column.Name]; !ok {
 			// only add to delta if not excluded
-			if _, ok := p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Exclude[column.Name]; !ok {
+			if _, ok := p.TableNameMapping[p.SrcTableIDNameMapping[msg.RelationID]].Exclude[column.Name]; !ok {
 				schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.FieldDescription{
 					Name: column.Name,
 					Type: currRelMap[column.Name],
@@ -883,23 +978,47 @@
 		}
 	}
 
-	p.relationMessageMapping[currRel.RelationID] = currRel
+	p.relationMessageMapping[msg.RelationID] = msg
 	// only log audit if there is actionable delta
 	if len(schemaDelta.AddedColumns) > 0 {
 		rec := &model.RelationRecord[Items]{
 			BaseRecord:       p.baseRecord(lsn),
 			TableSchemaDelta: schemaDelta,
 		}
-		return rec, auditSchemaDelta(ctx, p, rec)
+		logger := logger.LoggerFromCtx(ctx)
+		logger.Info(fmt.Sprintf("Detected schema change for table %s, addedColumns: %v",
+			schemaDelta.SrcTableName, schemaDelta.AddedColumns))
+		rp.records.AddSchemaDelta(rp.pullRequest.TableNameMapping, schemaDelta)
+		return auditSchemaDelta(ctx, p, rec)
 	}
-	return nil, nil
+	return nil
+}
+
+func (rp *cdcRecordProcessor[Items]) processLogicalDecodingMessage(
+	ctx context.Context,
+	p *PostgresCDCSource,
+	lsn pglogrepl.LSN,
+	msg *pglogrepl.LogicalDecodingMessage,
+) error {
+	p.logger.Info("LogicalDecodingMessage",
+		slog.Bool("Transactional", msg.Transactional),
+		slog.String("Prefix", msg.Prefix),
+		slog.Int64("LSN", int64(msg.LSN)))
+	if !msg.Transactional {
+		rp.records.UpdateLatestCheckpoint(int64(msg.LSN))
+	}
+	return rp.addRecordWithKey(ctx, p.logger, model.TableWithPkey{}, &model.MessageRecord[Items]{
+		BaseRecord: p.baseRecord(lsn),
+		Prefix:     msg.Prefix,
+		Content:    string(msg.Content),
+	})
 }
 
 func (p *PostgresCDCSource) getParentRelIDIfPartitioned(relID uint32) uint32 {
-	parentRelID, ok := p.childToParentRelIDMapping[relID]
-	if ok {
+	if parentRelID, ok := p.ChildToParentRelIDMap[relID]; ok {
 		return parentRelID
 	}
-
 	return relID
 }
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index 3e06f5367c..3907d49062 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -52,6 +52,7 @@ type ReplState struct {
 	Publication string
 	Offset      int64
 	LastOffset  atomic.Int64
+	Version     int32
 }
 
 func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) (*PostgresConnector, error) {
@@ -151,21 +152,25 @@ func (c *PostgresConnector) ReplPing(ctx context.Context) error {
 func (c *PostgresConnector) MaybeStartReplication(
 	ctx context.Context,
 	slotName string,
+	version int32,
 	publicationName string,
 	lastOffset int64,
 ) error {
 	if c.replState != nil &&
 		(c.replState.Offset != lastOffset ||
 			c.replState.Slot != slotName ||
+			c.replState.Version != version ||
 			c.replState.Publication != publicationName) {
-		msg := fmt.Sprintf("replState changed, reset connector. slot name: old=%s new=%s, publication: old=%s new=%s, offset: old=%d new=%d",
-			c.replState.Slot, slotName, c.replState.Publication, publicationName, c.replState.Offset, lastOffset,
+		msg := fmt.Sprintf(
+			"replState changed, reset connector. "+
+				"slot name: old=%s new=%s, version: old=%d new=%d, publication: old=%s new=%s, offset: old=%d new=%d",
+			c.replState.Slot, slotName, c.replState.Version, version, c.replState.Publication, publicationName, c.replState.Offset, lastOffset,
 		)
 		c.logger.Info(msg)
 		return temporal.NewNonRetryableApplicationError(msg, "desync", nil)
 	}
 
 	if c.replState == nil {
-		replicationOpts, err := c.replicationOptions(ctx, publicationName)
+		replicationOpts, err := c.replicationOptions(ctx, version, publicationName)
 		if err != nil {
 			return fmt.Errorf("error getting replication options: %w", err)
 		}
@@ -187,6 +192,7 @@ func (c *PostgresConnector) MaybeStartReplication(
 		c.replState = &ReplState{
 			Slot:        slotName,
 			Publication: publicationName,
+			Version:     version,
 			Offset:      lastOffset,
 			LastOffset:  atomic.Int64{},
 		}
@@ -195,8 +201,20 @@ func (c *PostgresConnector) MaybeStartReplication(
 
 	return nil
 }
 
-func (c *PostgresConnector) replicationOptions(ctx context.Context, publicationName string) (pglogrepl.StartReplicationOptions, error) {
-	pluginArguments := append(make([]string, 0, 3), "proto_version '1'")
+func (c *PostgresConnector) replicationOptions(
+	ctx context.Context,
+	version int32,
+	publicationName string,
+) (pglogrepl.StartReplicationOptions, error) {
+	var pluginArguments []string
+	switch version {
+	case 1:
+		pluginArguments = append(make([]string, 0, 3), "proto_version '1'")
+	case 2:
+		// proto_version '2' additionally asks the server to stream large
+		// in-progress transactions
+		pluginArguments = append(make([]string, 0, 4), "proto_version '2'", "streaming 'true'")
+	default:
+		return pglogrepl.StartReplicationOptions{}, errors.New("unsupported replication protocol version")
+	}
 
 	if publicationName != "" {
 		pubOpt := "publication_names " + QuoteLiteral(publicationName)
@@ -341,6 +359,7 @@ func pullCore[Items model.Items](
 		slotName = req.OverrideReplicationSlotName
 	}
 
+	version := int32(2)
 	publicationName := c.getDefaultPublicationName(req.FlowJobName)
 	if req.OverridePublicationName != "" {
 		publicationName = req.OverridePublicationName
@@ -369,7 +388,10 @@ func pullCore[Items model.Items](
 		return fmt.Errorf("error getting child to parent relid map: %w", err)
 	}
 
-	if err := c.MaybeStartReplication(ctx, slotName, publicationName, req.LastOffset); err != nil {
+	// hold replLock only while (re)starting replication: PullCdcRecords
+	// reacquires the same mutex around each ReceiveMessage call, so keeping
+	// it held across the pull would deadlock
+	if err := func() error {
+		c.replLock.Lock()
+		defer c.replLock.Unlock()
+		return c.MaybeStartReplication(ctx, slotName, version, publicationName, req.LastOffset)
+	}(); err != nil {
 		c.logger.Error("error starting replication", slog.Any("error", err))
 		return err
 	}
@@ -384,6 +406,7 @@ func pullCore[Items model.Items](
 		CatalogPool:            catalogPool,
 		FlowJobName:            req.FlowJobName,
 		RelationMessageMapping: c.relationMessageMapping,
+		Version:                version,
 	})
 
 	if err := PullCdcRecords(ctx, cdc, req, processor, &c.replLock); err != nil {
diff --git a/flow/connectors/utils/cdc_store.go b/flow/connectors/utils/cdc_store.go
index 6b36f73258..682ff05169 100644
--- a/flow/connectors/utils/cdc_store.go
+++ b/flow/connectors/utils/cdc_store.go
@@ -32,7 +32,7 @@ func encVal(val any) ([]byte, error) {
 	return buf.Bytes(), nil
 }
 
-type cdcStore[Items model.Items] struct {
+// CdcStore buffers a batch of CDC records keyed by primary key, spilling from
+// memory to pebble once the configured thresholds are exceeded.
+type CdcStore[Items model.Items] struct {
 	inMemoryRecords           map[model.TableWithPkey]model.Record[Items]
 	pebbleDB                  *pebble.DB
 	flowJobName               string
@@ -44,7 +44,7 @@ type cdcStore[Items model.Items] struct {
 	numRecordsSwitchThreshold int
 }
 
-func NewCDCStore[Items model.Items](ctx context.Context, flowJobName string) (*cdcStore[Items], error) {
+func NewCDCStore[Items model.Items](ctx context.Context, flowJobName string) (*CdcStore[Items], error) {
 	numRecordsSwitchThreshold, err := peerdbenv.PeerDBCDCDiskSpillRecordsThreshold(ctx)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get CDC disk spill records threshold: %w", err)
 	}
@@ -54,7 +54,7 @@ func NewCDCStore[Items model.Items](ctx context.Context, flowJobName string) (*c
 		return nil, fmt.Errorf("failed to get CDC disk spill memory percent threshold: %w", err)
 	}
 
-	return &cdcStore[Items]{
+	return &CdcStore[Items]{
 		inMemoryRecords: make(map[model.TableWithPkey]model.Record[Items]),
 		pebbleDB:        nil,
 		numRecords:      atomic.Int32{},
@@ -117,7 +117,7 @@ func init() {
 	gob.Register(qvalue.QValueArrayBoolean{})
 }
 
-func (c *cdcStore[T]) initPebbleDB() error {
+func (c *CdcStore[T]) initPebbleDB() error {
 	if c.pebbleDB != nil {
 		return nil
 	}
@@ -141,7 +141,7 @@
 	return nil
 }
 
-func (c *cdcStore[T]) diskSpillThresholdsExceeded() bool {
+func (c *CdcStore[T]) diskSpillThresholdsExceeded() bool {
 	if len(c.inMemoryRecords) >= c.numRecordsSwitchThreshold {
 		c.thresholdReason = fmt.Sprintf("more than %d primary keys read, spilling to disk",
 			c.numRecordsSwitchThreshold)
@@ -159,7 +159,7 @@
 	return false
 }
 
-func (c *cdcStore[T]) Set(logger log.Logger, key model.TableWithPkey, rec model.Record[T]) error {
+func (c *CdcStore[T]) Set(logger log.Logger, key model.TableWithPkey, rec model.Record[T]) error {
 	if key.TableName != "" {
 		_, ok := c.inMemoryRecords[key]
 		if ok || !c.diskSpillThresholdsExceeded() {
@@ -168,8 +168,7 @@ func (c *cdcStore[T]) Set(logger log.Logger, key model.TableWithPkey, rec model.
 			if c.pebbleDB == nil {
 				logger.Info(c.thresholdReason,
 					slog.String(string(shared.FlowNameKey), c.flowJobName))
-				err := c.initPebbleDB()
-				if err != nil {
+				if err := c.initPebbleDB(); err != nil {
 					return err
 				}
 			}
@@ -199,7 +198,7 @@
 }
 
 // bool is to indicate if a record is found or not [similar to ok]
-func (c *cdcStore[T]) Get(key model.TableWithPkey) (model.Record[T], bool, error) {
+func (c *CdcStore[T]) Get(key model.TableWithPkey) (model.Record[T], bool, error) {
 	rec, ok := c.inMemoryRecords[key]
 	if ok {
 		return rec, true, nil
@@ -227,8 +226,7 @@
 		dec := gob.NewDecoder(bytes.NewReader(encodedRec))
 
 		var rec model.Record[T]
-		err = dec.Decode(&rec)
-		if err != nil {
+		if err := dec.Decode(&rec); err != nil {
 			return nil, false, fmt.Errorf("failed to decode record: %w", err)
 		}
 
@@ -237,15 +235,15 @@
 	return nil, false, nil
 }
 
-func (c *cdcStore[T]) Len() int {
+func (c *CdcStore[T]) Len() int {
 	return int(c.numRecords.Load())
 }
 
-func (c *cdcStore[T]) IsEmpty() bool {
+func (c *CdcStore[T]) IsEmpty() bool {
 	return c.Len() == 0
 }
 
-func (c *cdcStore[T]) Close() error {
+func (c *CdcStore[T]) Close() error {
 	c.inMemoryRecords = nil
 	if c.pebbleDB != nil {
 		err := c.pebbleDB.Close()
diff --git a/nexus/catalog/migrations/V29__v2.sql b/nexus/catalog/migrations/V29__v2.sql
new file mode 100644
index 0000000000..4dbd34b173
--- /dev/null
+++ b/nexus/catalog/migrations/V29__v2.sql
@@ -0,0 +1,7 @@
+create table v2cdc (
+    flow_name text,
+    -- stored as bigint rather than the xid type: xid has no btree operator
+    -- class, so it could not be part of the primary key below
+    xid bigint,
+    lsn pg_lsn,
+    stream bytea[],
+    primary key (flow_name, xid, lsn)
+);
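+
+-- stream holds the raw XLogData segments of an in-progress streamed
+-- transaction; rows are replayed in lsn order at StreamCommit and deleted
+-- on StreamAbort.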