Skip to content

Commit

Permalink
Minor improvements for implicit errgroup initialization
Browse files Browse the repository at this point in the history
Should be coverage friendly
  • Loading branch information
egorse committed Dec 14, 2023
1 parent a5b84a6 commit 4fa0b49
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions internal/errgroup/errgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,15 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
// Go runs the provided f function in a dedicated goroutine and waits for its
// completion or for the parent context cancellation.
func (g *Group) Go(f func() error) {
if g.grp == nil {
g.grp = &errgroup.Group{}
}
g.grp.Go(g.wrap(f))
g.getErrGroup().Go(g.wrap(f))
}

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
// If the error group was created via WithContext then the Wait returns error
// of cancelled parent context prior any functions calls complete.
func (g *Group) Wait() error {
if g.grp == nil {
g.grp = &errgroup.Group{}
}
return g.grp.Wait()
return g.getErrGroup().Wait()
}

// SetLimit limits the number of active goroutines in this group to at most n.
Expand All @@ -55,21 +49,15 @@ func (g *Group) Wait() error {
//
// The limit must not be modified while any goroutines in the group are active.
func (g *Group) SetLimit(n int) {
if g.grp == nil {
g.grp = &errgroup.Group{}
}
g.grp.SetLimit(n)
g.getErrGroup().SetLimit(n)
}

// TryGo calls the given function in a new goroutine only if the number of
// active goroutines in the group is currently below the configured limit.
//
// The return value reports whether the goroutine was started.
func (g *Group) TryGo(f func() error) bool {
if g.grp == nil {
g.grp = &errgroup.Group{}
}
return g.grp.TryGo(g.wrap(f))
return g.getErrGroup().TryGo(g.wrap(f))
}

func (g *Group) wrap(f func() error) func() error {
Expand Down Expand Up @@ -106,3 +94,14 @@ func (g *Group) wrap(f func() error) func() error {
}
}
}

// The getErrGroup returns actual x/sync/errgroup.Group.
// If the group is not allocated it would implicitly allocate it.
// Thats allows the internal/errgroup.Group be fully
// compatible to x/sync/errgroup.Group
func (g *Group) getErrGroup() *errgroup.Group {
if g.grp == nil {
g.grp = &errgroup.Group{}
}
return g.grp
}

0 comments on commit 4fa0b49

Please sign in to comment.