diff --git a/internal/errgroup/errgroup.go b/internal/errgroup/errgroup.go index 078c970..ac4acfa 100644 --- a/internal/errgroup/errgroup.go +++ b/internal/errgroup/errgroup.go @@ -30,10 +30,7 @@ func WithContext(ctx context.Context) (*Group, context.Context) { // Go runs the provided f function in a dedicated goroutine and waits for its // completion or for the parent context cancellation. func (g *Group) Go(f func() error) { - if g.grp == nil { - g.grp = &errgroup.Group{} - } - g.grp.Go(g.wrap(f)) + g.getErrGroup().Go(g.wrap(f)) } // Wait blocks until all function calls from the Go method have returned, then @@ -41,10 +38,7 @@ func (g *Group) Go(f func() error) { // If the error group was created via WithContext then the Wait returns error // of cancelled parent context prior any functions calls complete. func (g *Group) Wait() error { - if g.grp == nil { - g.grp = &errgroup.Group{} - } - return g.grp.Wait() + return g.getErrGroup().Wait() } // SetLimit limits the number of active goroutines in this group to at most n. @@ -55,10 +49,7 @@ func (g *Group) Wait() error { // // The limit must not be modified while any goroutines in the group are active. func (g *Group) SetLimit(n int) { - if g.grp == nil { - g.grp = &errgroup.Group{} - } - g.grp.SetLimit(n) + g.getErrGroup().SetLimit(n) } // TryGo calls the given function in a new goroutine only if the number of @@ -66,10 +57,7 @@ func (g *Group) SetLimit(n int) { // // The return value reports whether the goroutine was started. func (g *Group) TryGo(f func() error) bool { - if g.grp == nil { - g.grp = &errgroup.Group{} - } - return g.grp.TryGo(g.wrap(f)) + return g.getErrGroup().TryGo(g.wrap(f)) } func (g *Group) wrap(f func() error) func() error { @@ -106,3 +94,14 @@ func (g *Group) wrap(f func() error) func() error { } } } + +// The getErrGroup returns actual x/sync/errgroup.Group. +// If the group is not allocated it would implicitly allocate it. +// Thats allows the internal/errgroup.Group be fully +// compatible to x/sync/errgroup.Group +func (g *Group) getErrGroup() *errgroup.Group { + if g.grp == nil { + g.grp = &errgroup.Group{} + } + return g.grp +}