diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 84a9234bb2..6c99b4f33e 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -5,9 +5,7 @@ import ( "errors" "fmt" "log/slog" - "sync" "sync/atomic" - "time" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" @@ -16,7 +14,7 @@ import ( "go.opentelemetry.io/otel/metric" "go.temporal.io/sdk/activity" "go.temporal.io/sdk/log" - "go.temporal.io/sdk/temporal" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" "github.com/PeerDB-io/peer-flow/alerting" @@ -43,19 +41,17 @@ type NormalizeBatchRequest struct { BatchID int64 } -type CdcCacheEntry struct { - connector connectors.CDCPullConnectorCore - syncDone chan struct{} - normalize chan NormalizeBatchRequest - normalizeDone chan struct{} +type CdcState struct { + connector connectors.CDCPullConnectorCore + syncDone chan struct{} + normalize chan NormalizeBatchRequest + errGroup *errgroup.Group } type FlowableActivity struct { CatalogPool *pgxpool.Pool Alerter *alerting.Alerter - CdcCache map[string]CdcCacheEntry OtelManager *otel_metrics.OtelManager - CdcCacheRw sync.RWMutex } func (a *FlowableActivity) CheckConnection( @@ -253,82 +249,116 @@ func (a *FlowableActivity) CreateNormalizedTable( }, nil } -func (a *FlowableActivity) MaintainPull( +func (a *FlowableActivity) maintainPull( ctx context.Context, config *protos.FlowConnectionConfigs, - sessionID string, -) error { +) (CdcState, context.Context, error) { ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) srcConn, err := connectors.GetByNameAs[connectors.CDCPullConnector](ctx, config.Env, a.CatalogPool, config.SourceName) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) - return err + return CdcState{}, nil, err } - defer connectors.CloseConnector(ctx, srcConn) if err := srcConn.SetupReplConn(ctx); err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) - return err + connectors.CloseConnector(ctx, srcConn) + return CdcState{}, nil, err } normalizeBufferSize, err := peerdbenv.PeerDBNormalizeChannelBufferSize(ctx, config.Env) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) - return err + connectors.CloseConnector(ctx, srcConn) + return CdcState{}, nil, err } - // syncDone will be closed by UnmaintainPull, - // whereas normalizeDone will be closed by the normalize goroutine + // syncDone will be closed by SyncFlow, + // whereas normalizeDone will be closed by normalizing goroutine // Wait on normalizeDone at end to not interrupt final normalize syncDone := make(chan struct{}) normalize := make(chan NormalizeBatchRequest, normalizeBufferSize) - normalizeDone := make(chan struct{}) - a.CdcCacheRw.Lock() - a.CdcCache[sessionID] = CdcCacheEntry{ - connector: srcConn, - syncDone: syncDone, - normalize: normalize, - normalizeDone: normalizeDone, - } - a.CdcCacheRw.Unlock() - - ticker := time.NewTicker(15 * time.Second) - defer ticker.Stop() - - go a.normalizeLoop(ctx, config, syncDone, normalize, normalizeDone) - - for { - select { - case <-ticker.C: - activity.RecordHeartbeat(ctx, "keep session alive") - if err := srcConn.ReplPing(ctx); err != nil { - a.CdcCacheRw.Lock() - delete(a.CdcCache, sessionID) - a.CdcCacheRw.Unlock() - a.Alerter.LogFlowError(ctx, config.FlowJobName, err) - return temporal.NewNonRetryableApplicationError("connection to source down", "disconnect", err) + + group, groupCtx := errgroup.WithContext(ctx) + group.Go(func() error { + // returning error signals sync to stop, normalize can recover connections without interrupting sync, so never return error + a.normalizeLoop(groupCtx, config, syncDone, normalize) + return nil + }) + group.Go(func() error { + defer connectors.CloseConnector(groupCtx, srcConn) + if err := a.maintainReplConn(groupCtx, config.FlowJobName, srcConn, syncDone); err != nil { + a.Alerter.LogFlowError(groupCtx, config.FlowJobName, err) + return err + } + return nil + }) + + return CdcState{ + connector: srcConn, + syncDone: syncDone, + normalize: normalize, + errGroup: group, + }, groupCtx, nil +} + +func (a *FlowableActivity) SyncFlow( + ctx context.Context, + config *protos.FlowConnectionConfigs, + options *protos.SyncFlowOptions, +) error { + ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) + logger := activity.GetLogger(ctx) + + cdcState, groupCtx, err := a.maintainPull(ctx, config) + if err != nil { + logger.Error("MaintainPull failed", slog.Any("error", err)) + return err + } + + currentSyncFlowNum := int32(0) + totalRecordsSynced := int64(0) + + for groupCtx.Err() == nil { + currentSyncFlowNum += 1 + logger.Info("executing sync flow", slog.Int("count", int(currentSyncFlowNum))) + + var numRecordsSynced int64 + var syncErr error + if config.System == protos.TypeSystem_Q { + numRecordsSynced, syncErr = a.SyncRecords(groupCtx, config, options, cdcState) + } else { + numRecordsSynced, syncErr = a.SyncPg(groupCtx, config, options, cdcState) + } + + if syncErr != nil { + if groupCtx.Err() != nil { + // need to return ctx.Err(), avoid returning syncErr that's wrapped context canceled + break + } + logger.Error("failed to sync records", slog.Any("error", syncErr)) + close(cdcState.syncDone) + return syncErr + } else { + totalRecordsSynced += numRecordsSynced + logger.Info("Total records synced", + slog.Int64("numRecordsSynced", numRecordsSynced), slog.Int64("totalRecordsSynced", totalRecordsSynced)) + + if options.NumberOfSyncs > 0 && currentSyncFlowNum >= options.NumberOfSyncs { + break } - case <-syncDone: - return nil - case <-ctx.Done(): - a.CdcCacheRw.Lock() - delete(a.CdcCache, sessionID) - a.CdcCacheRw.Unlock() - return nil } } -} -func (a *FlowableActivity) UnmaintainPull(ctx context.Context, sessionID string) error { - var normalizeDone chan struct{} - a.CdcCacheRw.Lock() - if entry, ok := a.CdcCache[sessionID]; ok { - close(entry.syncDone) - delete(a.CdcCache, sessionID) - normalizeDone = entry.normalizeDone + close(cdcState.syncDone) + waitErr := cdcState.errGroup.Wait() + if err := ctx.Err(); err != nil { + logger.Info("sync canceled", slog.Any("error", err)) + return err + } else if waitErr != nil { + logger.Error("sync failed", slog.Any("error", waitErr)) + return waitErr } - a.CdcCacheRw.Unlock() - <-normalizeDone return nil } @@ -336,8 +366,8 @@ func (a *FlowableActivity) SyncRecords( ctx context.Context, config *protos.FlowConnectionConfigs, options *protos.SyncFlowOptions, - sessionID string, -) (model.SyncRecordsResult, error) { + cdcState CdcState, +) (int64, error) { var adaptStream func(stream *model.CDCStream[model.RecordItems]) (*model.CDCStream[model.RecordItems], error) if config.Script != "" { var onErr context.CancelCauseFunc @@ -368,22 +398,20 @@ func (a *FlowableActivity) SyncRecords( return stream, nil } } - numRecords, err := syncCore(ctx, a, config, options, sessionID, adaptStream, + return syncCore(ctx, a, config, options, cdcState, adaptStream, connectors.CDCPullConnector.PullRecords, connectors.CDCSyncConnector.SyncRecords) - return model.SyncRecordsResult{NumRecordsSynced: numRecords}, err } func (a *FlowableActivity) SyncPg( ctx context.Context, config *protos.FlowConnectionConfigs, options *protos.SyncFlowOptions, - sessionID string, -) (model.SyncRecordsResult, error) { - numRecords, err := syncCore(ctx, a, config, options, sessionID, nil, + cdcState CdcState, +) (int64, error) { + return syncCore(ctx, a, config, options, cdcState, nil, connectors.CDCPullPgConnector.PullPg, connectors.CDCSyncPgConnector.SyncPg) - return model.SyncRecordsResult{NumRecordsSynced: numRecords}, err } func (a *FlowableActivity) StartNormalize( diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index b9bad18036..b3717c5559 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "log/slog" - "reflect" "sync/atomic" "time" @@ -50,43 +49,6 @@ func heartbeatRoutine( ) } -func waitForCdcCache[TPull connectors.CDCPullConnectorCore]( - ctx context.Context, a *FlowableActivity, sessionID string, -) (TPull, chan NormalizeBatchRequest, error) { - var none TPull - logger := activity.GetLogger(ctx) - attempt := 0 - waitInterval := time.Second - // try for 5 minutes, once per second - // after that, try indefinitely every minute - for { - a.CdcCacheRw.RLock() - entry, ok := a.CdcCache[sessionID] - a.CdcCacheRw.RUnlock() - if ok { - if conn, ok := entry.connector.(TPull); ok { - return conn, entry.normalize, nil - } - return none, nil, fmt.Errorf("expected %s, cache held %T", reflect.TypeFor[TPull]().Name(), entry.connector) - } - activity.RecordHeartbeat(ctx, fmt.Sprintf("wait %s for source connector", waitInterval)) - attempt += 1 - if attempt > 2 { - logger.Info("waiting on source connector setup", - slog.Int("attempt", attempt), slog.String("sessionID", sessionID)) - } - if err := ctx.Err(); err != nil { - return none, nil, err - } - time.Sleep(waitInterval) - if attempt == 300 { - logger.Info("source connector not setup in time, transition to slow wait", - slog.String("sessionID", sessionID)) - waitInterval = time.Minute - } - } -} - func (a *FlowableActivity) getTableNameSchemaMapping(ctx context.Context, flowName string) (map[string]*protos.TableSchema, error) { rows, err := a.CatalogPool.Query(ctx, "select table_name, table_schema from table_schema_mapping where flow_name = $1", flowName) if err != nil { @@ -142,7 +104,7 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon a *FlowableActivity, config *protos.FlowConnectionConfigs, options *protos.SyncFlowOptions, - sessionID string, + cdcState CdcState, adaptStream func(*model.CDCStream[Items]) (*model.CDCStream[Items], error), pull func(TPull, context.Context, *pgxpool.Pool, *otel_metrics.OtelManager, *model.PullRecordsRequest[Items]) error, sync func(TSync, context.Context, *model.SyncRecordsRequest[Items]) (*model.SyncResponse, error), @@ -160,10 +122,8 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude) } - srcConn, normChan, err := waitForCdcCache[TPull](ctx, a, sessionID) - if err != nil { - return 0, err - } + srcConn := cdcState.connector.(TPull) + normChan := cdcState.normalize if err := srcConn.ConnectionActive(ctx); err != nil { return 0, temporal.NewNonRetryableApplicationError("connection to source down", "disconnect", nil) } @@ -630,15 +590,35 @@ func replicateXminPartition[TRead any, TWrite any, TSync connectors.QRepSyncConn return currentSnapshotXmin, nil } +func (a *FlowableActivity) maintainReplConn( + ctx context.Context, flowName string, srcConn connectors.CDCPullConnector, syncDone <-chan struct{}, +) error { + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + activity.RecordHeartbeat(ctx, "keep session alive") + if err := srcConn.ReplPing(ctx); err != nil { + a.Alerter.LogFlowError(ctx, flowName, err) + return fmt.Errorf("connection to source down: %w", err) + } + case <-syncDone: + return nil + case <-ctx.Done(): + return nil + } + } +} + // Suitable to be run as goroutine func (a *FlowableActivity) normalizeLoop( ctx context.Context, config *protos.FlowConnectionConfigs, syncDone <-chan struct{}, normalize <-chan NormalizeBatchRequest, - normalizeDone chan struct{}, ) { - defer close(normalizeDone) logger := activity.GetLogger(ctx) for { diff --git a/flow/cmd/worker.go b/flow/cmd/worker.go index 87fbd0aa54..cfb4f0e049 100644 --- a/flow/cmd/worker.go +++ b/flow/cmd/worker.go @@ -167,7 +167,6 @@ func WorkerSetup(opts *WorkerSetupOptions) (*workerSetupResponse, error) { w.RegisterActivity(&activities.FlowableActivity{ CatalogPool: conn, Alerter: alerting.NewAlerter(context.Background(), conn), - CdcCache: make(map[string]activities.CdcCacheEntry), OtelManager: otelManager, }) diff --git a/flow/model/model.go b/flow/model/model.go index 63e0922cd2..5a9d7b76cc 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -169,10 +169,6 @@ type SyncResponse struct { CurrentSyncBatchID int64 } -type SyncRecordsResult struct { - NumRecordsSynced int64 -} - type NormalizeResponse struct { StartBatchID int64 EndBatchID int64 diff --git a/flow/model/signals.go b/flow/model/signals.go index 31a89ff778..17ee7d102a 100644 --- a/flow/model/signals.go +++ b/flow/model/signals.go @@ -134,7 +134,3 @@ var FlowSignal = TypedSignal[CDCFlowSignal]{ var CDCDynamicPropertiesSignal = TypedSignal[*protos.CDCFlowConfigUpdate]{ Name: "cdc-dynamic-properties", } - -var SyncStopSignal = TypedSignal[struct{}]{ - Name: "sync-stop", -} diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index c43b56a9c6..adcc82835a 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -473,21 +473,16 @@ func CDCFlowWorkflow( state.CurrentFlowStatus = protos.FlowStatus_STATUS_RUNNING } - syncFlowID := GetChildWorkflowID("sync-flow", cfg.FlowJobName, originalRunID) - - var restart, finished bool - syncFlowOpts := workflow.ChildWorkflowOptions{ - WorkflowID: syncFlowID, - ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, + var finished bool + syncCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 365 * 24 * time.Hour, + HeartbeatTimeout: time.Minute, + WaitForCancellation: true, RetryPolicy: &temporal.RetryPolicy{ - MaximumAttempts: 20, + InitialInterval: 30 * time.Second, }, - TypedSearchAttributes: mirrorNameSearch, - WaitForCancellation: true, - } - syncCtx := workflow.WithChildOptions(ctx, syncFlowOpts) - - syncFlowFuture := workflow.ExecuteChildWorkflow(syncCtx, SyncFlowWorkflow, cfg, state.SyncFlowOptions) + }) + syncFlowFuture := workflow.ExecuteActivity(syncCtx, flowable.SyncFlow, cfg, state.SyncFlowOptions) mainLoopSelector := workflow.NewNamedSelector(ctx, "MainLoop") mainLoopSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { @@ -509,7 +504,6 @@ func CDCFlowWorkflow( logger.Info("sync finished") } syncFlowFuture = nil - restart = true finished = true if state.SyncFlowOptions.NumberOfSyncs > 0 { state.ActiveSignal = model.PauseSignal @@ -537,16 +531,10 @@ func CDCFlowWorkflow( } if shared.ShouldWorkflowContinueAsNew(ctx) { - restart = true - if syncFlowFuture != nil { - if err := model.SyncStopSignal.SignalChildWorkflow(ctx, syncFlowFuture, struct{}{}).Get(ctx, nil); err != nil { - logger.Warn("failed to send sync-stop, finishing", slog.Any("error", err)) - finished = true - } - } + finished = true } - if restart || finished { + if finished { for ctx.Err() == nil && (!finished || mainLoopSelector.HasPending()) { mainLoopSelector.Select(ctx) } diff --git a/flow/workflows/register.go b/flow/workflows/register.go index 4c458319cc..594fc989cd 100644 --- a/flow/workflows/register.go +++ b/flow/workflows/register.go @@ -8,7 +8,6 @@ func RegisterFlowWorkerWorkflows(w worker.WorkflowRegistry) { w.RegisterWorkflow(CDCFlowWorkflow) w.RegisterWorkflow(DropFlowWorkflow) w.RegisterWorkflow(SetupFlowWorkflow) - w.RegisterWorkflow(SyncFlowWorkflow) w.RegisterWorkflow(QRepFlowWorkflow) w.RegisterWorkflow(QRepWaitForNewRowsWorkflow) w.RegisterWorkflow(QRepPartitionWorkflow) diff --git a/flow/workflows/sync_flow.go b/flow/workflows/sync_flow.go deleted file mode 100644 index 788d34cd2b..0000000000 --- a/flow/workflows/sync_flow.go +++ /dev/null @@ -1,145 +0,0 @@ -package peerflow - -import ( - "log/slog" - "time" - - "go.temporal.io/sdk/log" - "go.temporal.io/sdk/temporal" - "go.temporal.io/sdk/workflow" - - "github.com/PeerDB-io/peer-flow/generated/protos" - "github.com/PeerDB-io/peer-flow/model" - "github.com/PeerDB-io/peer-flow/shared" -) - -func SyncFlowWorkflow( - ctx workflow.Context, - config *protos.FlowConnectionConfigs, - options *protos.SyncFlowOptions, -) error { - logger := log.With(workflow.GetLogger(ctx), slog.String(string(shared.FlowNameKey), config.FlowJobName)) - - sessionOptions := &workflow.SessionOptions{ - CreationTimeout: 5 * time.Minute, - ExecutionTimeout: 14 * 24 * time.Hour, - HeartbeatTimeout: time.Minute, - } - - syncSessionCtx, err := workflow.CreateSession(ctx, sessionOptions) - if err != nil { - return err - } - defer workflow.CompleteSession(syncSessionCtx) - - sessionID := workflow.GetSessionInfo(syncSessionCtx).SessionID - maintainCtx := workflow.WithActivityOptions(syncSessionCtx, workflow.ActivityOptions{ - StartToCloseTimeout: 30 * 24 * time.Hour, - HeartbeatTimeout: time.Hour, - WaitForCancellation: true, - RetryPolicy: &temporal.RetryPolicy{MaximumAttempts: 1}, - }) - fMaintain := workflow.ExecuteActivity( - maintainCtx, - flowable.MaintainPull, - config, - sessionID, - ) - - var stop, syncErr bool - currentSyncFlowNum := int32(0) - totalRecordsSynced := int64(0) - - selector := workflow.NewNamedSelector(ctx, "SyncLoop") - selector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {}) - selector.AddFuture(fMaintain, func(f workflow.Future) { - if err := f.Get(ctx, nil); err != nil { - logger.Error("MaintainPull failed", slog.Any("error", err)) - syncErr = true - } - }) - - stopChan := model.SyncStopSignal.GetSignalChannel(ctx) - stopChan.AddToSelector(selector, func(_ struct{}, _ bool) { - stop = true - }) - - syncFlowCtx := workflow.WithActivityOptions(syncSessionCtx, workflow.ActivityOptions{ - StartToCloseTimeout: 7 * 24 * time.Hour, - HeartbeatTimeout: time.Minute, - WaitForCancellation: true, - RetryPolicy: &temporal.RetryPolicy{ - InitialInterval: 30 * time.Second, - }, - }) - for !stop && ctx.Err() == nil { - var syncDone bool - - currentSyncFlowNum += 1 - logger.Info("executing sync flow", slog.Int("count", int(currentSyncFlowNum))) - - var syncFlowFuture workflow.Future - if config.System == protos.TypeSystem_Q { - syncFlowFuture = workflow.ExecuteActivity(syncFlowCtx, flowable.SyncRecords, config, options, sessionID) - } else { - syncFlowFuture = workflow.ExecuteActivity(syncFlowCtx, flowable.SyncPg, config, options, sessionID) - } - selector.AddFuture(syncFlowFuture, func(f workflow.Future) { - syncDone = true - - var syncResult model.SyncRecordsResult - if err := f.Get(ctx, &syncResult); err != nil { - logger.Error("failed to execute sync flow", slog.Any("error", err)) - syncErr = true - } else { - totalRecordsSynced += syncResult.NumRecordsSynced - logger.Info("Total records synced", - slog.Int64("numRecordsSynced", syncResult.NumRecordsSynced), slog.Int64("totalRecordsSynced", totalRecordsSynced)) - } - }) - - for ctx.Err() == nil && ((!syncDone && !syncErr) || selector.HasPending()) { - selector.Select(ctx) - } - - if syncErr { - logger.Info("sync flow error, sleeping for 30 seconds...") - err := workflow.Sleep(ctx, 30*time.Second) - if err != nil { - logger.Error("failed to sleep", slog.Any("error", err)) - } - } - - if (options.NumberOfSyncs > 0 && currentSyncFlowNum >= options.NumberOfSyncs) || - syncErr || ctx.Err() != nil || shared.ShouldWorkflowContinueAsNew(ctx) { - break - } - } - - if err := ctx.Err(); err != nil { - logger.Info("sync canceled", slog.Any("error", err)) - return err - } - - unmaintainCtx := workflow.WithActivityOptions(syncSessionCtx, workflow.ActivityOptions{ - RetryPolicy: &temporal.RetryPolicy{MaximumAttempts: 1}, - StartToCloseTimeout: time.Minute, - HeartbeatTimeout: time.Minute, - WaitForCancellation: true, - }) - if err := workflow.ExecuteActivity( - unmaintainCtx, - flowable.UnmaintainPull, - sessionID, - ).Get(unmaintainCtx, nil); err != nil { - logger.Warn("UnmaintainPull failed", slog.Any("error", err)) - } - - if stop || currentSyncFlowNum >= options.NumberOfSyncs { - return nil - } else if _, stop := stopChan.ReceiveAsync(); stop { - // if sync flow erroring may outrace receiving stop - return nil - } - return workflow.NewContinueAsNewError(ctx, SyncFlowWorkflow, config, options) -}