move syncing records into one long running activity,
allowing frequent sync batches without needing to reconnect
serprex committed Dec 23, 2024
1 parent 6d49e2f commit a16e168
Showing 8 changed files with 132 additions and 291 deletions.
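
In outline, the commit replaces the session-keyed CdcCache map (guarded by CdcCacheRw and populated by a separate MaintainPull activity) with one long-running SyncFlow activity that owns the source connector and loops sync batches in-process. A minimal sketch of that shape, assuming simplified stand-ins — syncFlowSketch, syncOneBatch, and keepAlive are hypothetical names for illustration, not PeerDB APIs:

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

// Hypothetical stand-in for one pull+sync batch over an already-open connection.
func syncOneBatch(ctx context.Context) (int64, error) { return 42, nil }

// keepAlive pings the source on a ticker until sync finishes,
// loosely mirroring the maintainReplConn loop added in flowable_core.go.
func keepAlive(ctx context.Context, done <-chan struct{}) error {
	t := time.NewTicker(15 * time.Second)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			// a real implementation would ping the replication connection here
		case <-done:
			return nil
		case <-ctx.Done():
			return nil
		}
	}
}

// syncFlowSketch mirrors the new SyncFlow structure: background goroutines
// run under an errgroup while the main loop syncs batch after batch over
// the same connection, so no per-batch reconnect is needed.
func syncFlowSketch(ctx context.Context, numberOfSyncs int32) error {
	syncDone := make(chan struct{})
	group, groupCtx := errgroup.WithContext(ctx)
	group.Go(func() error { return keepAlive(groupCtx, syncDone) })

	var total int64
	for n := int32(1); groupCtx.Err() == nil; n++ {
		synced, err := syncOneBatch(groupCtx)
		if err != nil {
			close(syncDone) // stop background goroutines before surfacing the error
			_ = group.Wait()
			return err
		}
		total += synced
		if numberOfSyncs > 0 && n >= numberOfSyncs {
			break
		}
	}
	close(syncDone) // tells keepAlive (and, in the real code, normalize) to wind down
	return group.Wait()
}

func main() {
	fmt.Println("err:", syncFlowSketch(context.Background(), 3))
}
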
166 changes: 97 additions & 69 deletions flow/activities/flowable.go
@@ -5,9 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"

"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5"
@@ -16,7 +14,7 @@ import (
"go.opentelemetry.io/otel/metric"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"
"go.temporal.io/sdk/temporal"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"

"github.com/PeerDB-io/peer-flow/alerting"
@@ -43,19 +41,17 @@ type NormalizeBatchRequest struct {
BatchID int64
}

type CdcCacheEntry struct {
connector connectors.CDCPullConnectorCore
syncDone chan struct{}
normalize chan NormalizeBatchRequest
normalizeDone chan struct{}
type CdcState struct {
connector connectors.CDCPullConnectorCore
syncDone chan struct{}
normalize chan NormalizeBatchRequest
errGroup *errgroup.Group
}

type FlowableActivity struct {
CatalogPool *pgxpool.Pool
Alerter *alerting.Alerter
CdcCache map[string]CdcCacheEntry
OtelManager *otel_metrics.OtelManager
CdcCacheRw sync.RWMutex
}

func (a *FlowableActivity) CheckConnection(
@@ -253,91 +249,125 @@ func (a *FlowableActivity) CreateNormalizedTable(
}, nil
}

func (a *FlowableActivity) MaintainPull(
func (a *FlowableActivity) maintainPull(
ctx context.Context,
config *protos.FlowConnectionConfigs,
sessionID string,
) error {
) (CdcState, context.Context, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
srcConn, err := connectors.GetByNameAs[connectors.CDCPullConnector](ctx, config.Env, a.CatalogPool, config.SourceName)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
return CdcState{}, nil, err
}
defer connectors.CloseConnector(ctx, srcConn)

if err := srcConn.SetupReplConn(ctx); err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
connectors.CloseConnector(ctx, srcConn)
return CdcState{}, nil, err
}

normalizeBufferSize, err := peerdbenv.PeerDBNormalizeChannelBufferSize(ctx, config.Env)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
connectors.CloseConnector(ctx, srcConn)
return CdcState{}, nil, err
}

// syncDone will be closed by UnmaintainPull,
// whereas normalizeDone will be closed by the normalize goroutine
// syncDone will be closed by SyncFlow,
// whereas normalizeDone will be closed by normalizing goroutine
// Wait on normalizeDone at end to not interrupt final normalize
syncDone := make(chan struct{})
normalize := make(chan NormalizeBatchRequest, normalizeBufferSize)
normalizeDone := make(chan struct{})
a.CdcCacheRw.Lock()
a.CdcCache[sessionID] = CdcCacheEntry{
connector: srcConn,
syncDone: syncDone,
normalize: normalize,
normalizeDone: normalizeDone,
}
a.CdcCacheRw.Unlock()

ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()

go a.normalizeLoop(ctx, config, syncDone, normalize, normalizeDone)

for {
select {
case <-ticker.C:
activity.RecordHeartbeat(ctx, "keep session alive")
if err := srcConn.ReplPing(ctx); err != nil {
a.CdcCacheRw.Lock()
delete(a.CdcCache, sessionID)
a.CdcCacheRw.Unlock()
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return temporal.NewNonRetryableApplicationError("connection to source down", "disconnect", err)

group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
// returning error signals sync to stop, normalize can recover connections without interrupting sync, so never return error
a.normalizeLoop(groupCtx, config, syncDone, normalize)
return nil
})
group.Go(func() error {
defer connectors.CloseConnector(groupCtx, srcConn)
if err := a.maintainReplConn(groupCtx, config.FlowJobName, srcConn, syncDone); err != nil {
a.Alerter.LogFlowError(groupCtx, config.FlowJobName, err)
return err
}
return nil
})

return CdcState{
connector: srcConn,
syncDone: syncDone,
normalize: normalize,
errGroup: group,
}, groupCtx, nil
}

func (a *FlowableActivity) SyncFlow(
ctx context.Context,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
logger := activity.GetLogger(ctx)

cdcState, groupCtx, err := a.maintainPull(ctx, config)
if err != nil {
logger.Error("MaintainPull failed", slog.Any("error", err))
return err
}

currentSyncFlowNum := int32(0)
totalRecordsSynced := int64(0)

for groupCtx.Err() == nil {
currentSyncFlowNum += 1
logger.Info("executing sync flow", slog.Int("count", int(currentSyncFlowNum)))

var numRecordsSynced int64
var syncErr error
if config.System == protos.TypeSystem_Q {
numRecordsSynced, syncErr = a.SyncRecords(groupCtx, config, options, cdcState)
} else {
numRecordsSynced, syncErr = a.SyncPg(groupCtx, config, options, cdcState)
}

if syncErr != nil {
if groupCtx.Err() != nil {
// need to return ctx.Err(), avoid returning syncErr that's wrapped context canceled
break
}
logger.Error("failed to sync records", slog.Any("error", syncErr))
close(cdcState.syncDone)
return syncErr
} else {
totalRecordsSynced += numRecordsSynced
logger.Info("Total records synced",
slog.Int64("numRecordsSynced", numRecordsSynced), slog.Int64("totalRecordsSynced", totalRecordsSynced))

if options.NumberOfSyncs > 0 && currentSyncFlowNum >= options.NumberOfSyncs {
break
}
case <-syncDone:
return nil
case <-ctx.Done():
a.CdcCacheRw.Lock()
delete(a.CdcCache, sessionID)
a.CdcCacheRw.Unlock()
return nil
}
}
}

func (a *FlowableActivity) UnmaintainPull(ctx context.Context, sessionID string) error {
var normalizeDone chan struct{}
a.CdcCacheRw.Lock()
if entry, ok := a.CdcCache[sessionID]; ok {
close(entry.syncDone)
delete(a.CdcCache, sessionID)
normalizeDone = entry.normalizeDone
close(cdcState.syncDone)
waitErr := cdcState.errGroup.Wait()
if err := ctx.Err(); err != nil {
logger.Info("sync canceled", slog.Any("error", err))
return err
} else if waitErr != nil {
logger.Error("sync failed", slog.Any("error", waitErr))
return waitErr
}
a.CdcCacheRw.Unlock()
<-normalizeDone
return nil
}

func (a *FlowableActivity) SyncRecords(
ctx context.Context,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
sessionID string,
) (model.SyncRecordsResult, error) {
cdcState CdcState,
) (int64, error) {
var adaptStream func(stream *model.CDCStream[model.RecordItems]) (*model.CDCStream[model.RecordItems], error)
if config.Script != "" {
var onErr context.CancelCauseFunc
Expand Down Expand Up @@ -368,22 +398,20 @@ func (a *FlowableActivity) SyncRecords(
return stream, nil
}
}
numRecords, err := syncCore(ctx, a, config, options, sessionID, adaptStream,
return syncCore(ctx, a, config, options, cdcState, adaptStream,
connectors.CDCPullConnector.PullRecords,
connectors.CDCSyncConnector.SyncRecords)
return model.SyncRecordsResult{NumRecordsSynced: numRecords}, err
}

func (a *FlowableActivity) SyncPg(
ctx context.Context,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
sessionID string,
) (model.SyncRecordsResult, error) {
numRecords, err := syncCore(ctx, a, config, options, sessionID, nil,
cdcState CdcState,
) (int64, error) {
return syncCore(ctx, a, config, options, cdcState, nil,
connectors.CDCPullPgConnector.PullPg,
connectors.CDCSyncPgConnector.SyncPg)
return model.SyncRecordsResult{NumRecordsSynced: numRecords}, err
}

func (a *FlowableActivity) StartNormalize(
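
A note on the errgroup pattern introduced in flowable.go above: with errgroup.WithContext, the first goroutine to return a non-nil error cancels the shared context for all members, and Wait reports that first error. That is why the normalize goroutine deliberately always returns nil — a normalize failure must not tear down sync. A self-contained illustration of the standard golang.org/x/sync/errgroup behavior (not PeerDB code):

package main

import (
	"context"
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	group, ctx := errgroup.WithContext(context.Background())

	group.Go(func() error {
		// simulate the keepalive goroutine failing its ping
		return errors.New("connection to source down")
	})
	group.Go(func() error {
		<-ctx.Done() // the sibling's error canceled the shared context
		return nil   // returning nil never masks the first error
	})

	// Wait blocks until both goroutines finish and yields the first error.
	fmt.Println(group.Wait()) // connection to source down
}
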
70 changes: 25 additions & 45 deletions flow/activities/flowable_core.go
@@ -5,7 +5,6 @@ import (
"context"
"fmt"
"log/slog"
"reflect"
"sync/atomic"
"time"

@@ -50,43 +49,6 @@ func heartbeatRoutine(
)
}

func waitForCdcCache[TPull connectors.CDCPullConnectorCore](
ctx context.Context, a *FlowableActivity, sessionID string,
) (TPull, chan NormalizeBatchRequest, error) {
var none TPull
logger := activity.GetLogger(ctx)
attempt := 0
waitInterval := time.Second
// try for 5 minutes, once per second
// after that, try indefinitely every minute
for {
a.CdcCacheRw.RLock()
entry, ok := a.CdcCache[sessionID]
a.CdcCacheRw.RUnlock()
if ok {
if conn, ok := entry.connector.(TPull); ok {
return conn, entry.normalize, nil
}
return none, nil, fmt.Errorf("expected %s, cache held %T", reflect.TypeFor[TPull]().Name(), entry.connector)
}
activity.RecordHeartbeat(ctx, fmt.Sprintf("wait %s for source connector", waitInterval))
attempt += 1
if attempt > 2 {
logger.Info("waiting on source connector setup",
slog.Int("attempt", attempt), slog.String("sessionID", sessionID))
}
if err := ctx.Err(); err != nil {
return none, nil, err
}
time.Sleep(waitInterval)
if attempt == 300 {
logger.Info("source connector not setup in time, transition to slow wait",
slog.String("sessionID", sessionID))
waitInterval = time.Minute
}
}
}

func (a *FlowableActivity) getTableNameSchemaMapping(ctx context.Context, flowName string) (map[string]*protos.TableSchema, error) {
rows, err := a.CatalogPool.Query(ctx, "select table_name, table_schema from table_schema_mapping where flow_name = $1", flowName)
if err != nil {
@@ -142,7 +104,7 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon
a *FlowableActivity,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
sessionID string,
cdcState CdcState,
adaptStream func(*model.CDCStream[Items]) (*model.CDCStream[Items], error),
pull func(TPull, context.Context, *pgxpool.Pool, *otel_metrics.OtelManager, *model.PullRecordsRequest[Items]) error,
sync func(TSync, context.Context, *model.SyncRecordsRequest[Items]) (*model.SyncResponse, error),
@@ -160,10 +122,8 @@
tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude)
}

srcConn, normChan, err := waitForCdcCache[TPull](ctx, a, sessionID)
if err != nil {
return 0, err
}
srcConn := cdcState.connector.(TPull)
normChan := cdcState.normalize
if err := srcConn.ConnectionActive(ctx); err != nil {
return 0, temporal.NewNonRetryableApplicationError("connection to source down", "disconnect", nil)
}
@@ -630,15 +590,35 @@ func replicateXminPartition[TRead any, TWrite any, TSync connectors.QRepSyncConn
return currentSnapshotXmin, nil
}

func (a *FlowableActivity) maintainReplConn(
ctx context.Context, flowName string, srcConn connectors.CDCPullConnector, syncDone <-chan struct{},
) error {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()

for {
select {
case <-ticker.C:
activity.RecordHeartbeat(ctx, "keep session alive")
if err := srcConn.ReplPing(ctx); err != nil {
a.Alerter.LogFlowError(ctx, flowName, err)
return fmt.Errorf("connection to source down: %w", err)
}
case <-syncDone:
return nil
case <-ctx.Done():
return nil
}
}
}

// Suitable to be run as goroutine
func (a *FlowableActivity) normalizeLoop(
ctx context.Context,
config *protos.FlowConnectionConfigs,
syncDone <-chan struct{},
normalize <-chan NormalizeBatchRequest,
normalizeDone chan struct{},
) {
defer close(normalizeDone)
logger := activity.GetLogger(ctx)

for {
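
The sync/normalize handoff above relies on a buffered channel: each sync batch enqueues a NormalizeBatchRequest and moves on, while normalizeLoop consumes at its own pace, and the final normalize is allowed to finish before shutdown. A minimal sketch of that handoff, assuming an illustrative request type and buffer size (the real size comes from PeerDBNormalizeChannelBufferSize):

package main

import (
	"fmt"
	"time"
)

// Illustrative request type; the real NormalizeBatchRequest carries a BatchID.
type normalizeRequest struct{ batchID int64 }

func main() {
	normalize := make(chan normalizeRequest, 4) // buffer decouples sync from normalize
	normalizeDone := make(chan struct{})

	go func() { // normalize loop: drain every request, then signal completion
		defer close(normalizeDone)
		for req := range normalize {
			time.Sleep(10 * time.Millisecond) // pretend normalizing is slow
			fmt.Println("normalized batch", req.batchID)
		}
	}()

	for id := int64(1); id <= 6; id++ {
		normalize <- normalizeRequest{batchID: id} // sync produces without waiting
	}
	close(normalize) // no more batches
	<-normalizeDone  // wait so the final normalize is not interrupted
}
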
1 change: 0 additions & 1 deletion flow/cmd/worker.go
@@ -167,7 +167,6 @@ func WorkerSetup(opts *WorkerSetupOptions) (*workerSetupResponse, error) {
w.RegisterActivity(&activities.FlowableActivity{
CatalogPool: conn,
Alerter: alerting.NewAlerter(context.Background(), conn),
CdcCache: make(map[string]activities.CdcCacheEntry),
OtelManager: otelManager,
})

4 changes: 0 additions & 4 deletions flow/model/model.go
@@ -169,10 +169,6 @@ type SyncResponse struct {
CurrentSyncBatchID int64
}

type SyncRecordsResult struct {
NumRecordsSynced int64
}

type NormalizeResponse struct {
StartBatchID int64
EndBatchID int64
4 changes: 0 additions & 4 deletions flow/model/signals.go
@@ -134,7 +134,3 @@ var FlowSignal = TypedSignal[CDCFlowSignal]{
var CDCDynamicPropertiesSignal = TypedSignal[*protos.CDCFlowConfigUpdate]{
Name: "cdc-dynamic-properties",
}

var SyncStopSignal = TypedSignal[struct{}]{
Name: "sync-stop",
}