diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go
index 1c35794658d49..56879a3455f26 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go
@@ -134,7 +134,11 @@ func Test_preprocessor_preProcessGraph(t *testing.T) {
 			}})
 
 			gotStages := pre.preProcessGraph(test.input, nil)
-			if diff := cmp.Diff(test.wantStages, gotStages, cmp.AllowUnexported(stage{}, link{}), cmpopts.EquateEmpty()); diff != "" {
+			if diff := cmp.Diff(test.wantStages, gotStages,
+				cmp.AllowUnexported(stage{}, link{}),
+				cmpopts.EquateEmpty(),
+				cmpopts.IgnoreFields(stage{}, "baseProgTick"),
+			); diff != "" {
 				t.Errorf("preProcessGraph(%q) stages diff (-want,+got)\n%v", test.name, diff)
 			}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go
index da23ca8ccce1e..7c55f6d9a23f7 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/stage.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"sync/atomic"
 	"time"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
@@ -76,9 +77,32 @@ type stage struct {
 
 	SinkToPCollection map[string]string
 	OutputsToCoders   map[string]engine.PColInfo
+
+	// Stage-specific progress and splitting interval.
+	baseProgTick atomic.Value // time.Duration
+}
+
+// The minimum and maximum durations between each ProgressBundleRequest and split evaluation.
+const (
+	minimumProgTick = 10 * time.Millisecond
+	maximumProgTick = 30 * time.Second
+)
+
+func clampTick(dur time.Duration) time.Duration {
+	switch {
+	case dur < minimumProgTick:
+		return minimumProgTick
+	case dur > maximumProgTick:
+		return maximumProgTick
+	default:
+		return dur
+	}
 }
 
 func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) (err error) {
+	if s.baseProgTick.Load() == nil {
+		s.baseProgTick.Store(minimumProgTick)
+	}
 	defer func() { // Convert execution panics to errors to fail the bundle.
 		if e := recover(); e != nil {
@@ -142,7 +166,9 @@
 	previousTotalCount := int64(-2) // Total count of all pcollection elements.
 
 	unsplit := true
-	progTick := time.NewTicker(100 * time.Millisecond)
+	baseTick := s.baseProgTick.Load().(time.Duration)
+	ticked := false
+	progTick := time.NewTicker(baseTick)
 	defer progTick.Stop()
 	var dataFinished, bundleFinished bool
 	// If we have no data outputs, we still need to have progress & splits
@@ -170,6 +196,7 @@ progress:
 				break progress // exit progress loop on close.
 			}
 		case <-progTick.C:
+			ticked = true
 			resp, err := b.Progress(ctx, wk)
 			if err != nil {
 				slog.Debug("SDK Error from progress, aborting progress", "bundle", rb, "error", err.Error())
@@ -196,6 +223,7 @@ progress:
 					unsplit = false
 					continue progress
 				}
+
 				// TODO sort out rescheduling primary Roots on bundle failure.
 				var residuals []engine.Residual
 				for _, rr := range sr.GetResidualRoots() {
@@ -220,12 +248,28 @@ progress:
 						Data: residuals,
 					})
 				}
+
+				// Any split means we're processing slower than desired, but splitting should increase
+				// throughput. Back off for this and other bundles in this stage.
+				baseTime := s.baseProgTick.Load().(time.Duration)
+				newTime := clampTick(baseTime * 4)
+				if s.baseProgTick.CompareAndSwap(baseTime, newTime) {
+					progTick.Reset(newTime)
+				} else {
+					progTick.Reset(s.baseProgTick.Load().(time.Duration))
+				}
 			} else {
 				previousIndex = index["index"]
 				previousTotalCount = index["totalCount"]
 			}
 		}
 	}
+	// If we never received any progress ticks, the interval may be too long; shrink it for future bundles instead.
+	if !ticked {
+		newTick := clampTick(baseTick / 2)
+		// If it's otherwise unchanged, apply the new duration.
+		s.baseProgTick.CompareAndSwap(baseTick, newTick)
+	}
 
 	// Tentative Data is ready, commit it to the main datastore.
 	slog.Debug("Execute: committing data", "bundle", rb, slog.Any("outputsWithData", maps.Keys(b.OutputData.Raw)), slog.Any("outputs", maps.Keys(s.OutputsToCoders)))
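
For reviewers, the sketch below (not part of the change) illustrates the interval dynamics this diff introduces: the per-stage progress tick quadruples after a successful split, halves when a bundle finishes without ever ticking, and is always clamped to the `[minimumProgTick, maximumProgTick]` range. The constants and `clampTick` mirror the diff; the `main` driver is purely illustrative.

```go
package main

import (
	"fmt"
	"time"
)

// Values copied from the stage.go change above.
const (
	minimumProgTick = 10 * time.Millisecond
	maximumProgTick = 30 * time.Second
)

// clampTick bounds a proposed interval, as in the diff.
func clampTick(dur time.Duration) time.Duration {
	switch {
	case dur < minimumProgTick:
		return minimumProgTick
	case dur > maximumProgTick:
		return maximumProgTick
	default:
		return dur
	}
}

func main() {
	// Back off after repeated splits: the interval quadruples, capped at 30s.
	tick := minimumProgTick
	for i := 0; i < 8; i++ {
		tick = clampTick(tick * 4)
		fmt.Println("after split", i+1, "->", tick)
	}

	// Shrink when bundles complete without any progress tick: the interval
	// halves each time, floored at 10ms.
	for i := 0; i < 14; i++ {
		tick = clampTick(tick / 2)
	}
	fmt.Println("after repeated shrinks ->", tick) // bottoms out at minimumProgTick
}
```

Because the backoff uses `CompareAndSwap` on the shared `baseProgTick`, concurrent bundles of the same stage that race to back off simply re-read the winner's value and reset their tickers to it, so the interval only moves once per observed split.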