diff --git a/example/goroutine-leak/cmd/webserver/internal/analytics/analytics.go b/example/goroutine-leak/cmd/webserver/internal/analytics/analytics.go index 562870c..ac9d315 100644 --- a/example/goroutine-leak/cmd/webserver/internal/analytics/analytics.go +++ b/example/goroutine-leak/cmd/webserver/internal/analytics/analytics.go @@ -16,5 +16,10 @@ func NewAnalytics() *Analytics { func (_ *Analytics) Send(ctx context.Context, message string, args ...string) { d := delay.FromContext(ctx) println("Analytics.Send " + d.String()) - <-time.After(d) + + select { + case <-ctx.Done(): + println("Analytics.Send timeout") + case <-time.After(d): + } } diff --git a/example/goroutine-leak/cmd/webserver/internal/service/foo.go b/example/goroutine-leak/cmd/webserver/internal/service/foo.go new file mode 100644 index 0000000..ee36243 --- /dev/null +++ b/example/goroutine-leak/cmd/webserver/internal/service/foo.go @@ -0,0 +1,92 @@ +package service + +import ( + "context" + "errors" + "fmt" +) + +type FooResponse struct { + BarID int + BazID int +} + +type barClient interface { + GetBarID(ctx context.Context, id int) (barID int, err error) +} + +type bazClient interface { + GetBazID(ctx context.Context, id int) (bazID int, err error) +} + +type pool interface { + Enqueue(context.Context, func(poolCtx, taskCtx context.Context)) error +} + +type Foo struct { + pool pool + bar barClient + baz bazClient + // ... +} + +func NewFooService(pool pool, bar barClient, baz bazClient) *Foo { + return &Foo{ + pool: pool, + bar: bar, + baz: baz, + } +} + +type MyFuncType func(ctx context.Context, id int) (barID int, err error) + +func (f *Foo) someFunc(con context.Context, stopCon context.Context, cancelFunc context.CancelFunc, getFunc MyFuncType, id int, result *int, err *error) chan struct{} { + finish := make(chan struct{}) + _ = f.pool.Enqueue(con, func(_, _ context.Context) { + + defer func() { + finish <- struct{}{} + }() + + res, respErr := getFunc(con, id) + if respErr != nil { + err = &respErr + cancelFunc() + return + } + + result = &res + }) + + select { + case <-finish: + return finish + case <-stopCon.Done(): + return finish + } +} + +// GET /api/v1/foo -> json FooResponse +func (f *Foo) Foo(ctx context.Context, id int) (*FooResponse, error) { + defContext, cancel := context.WithCancel(ctx) + + var barID int + var bazID int + + var errorBar error + var errorBaz error + + a := f.someFunc(ctx, defContext, cancel, f.bar.GetBarID, id, &barID, &errorBar) + b := f.someFunc(ctx, defContext, cancel, f.baz.GetBazID, id, &bazID, &errorBaz) + d, e := <-a, <-b + fmt.Print(d, e) + + if errorBar != nil || errorBaz != nil { + return nil, errors.Join(errorBar, errorBaz) + } + + return &FooResponse{ + BarID: barID, + BazID: bazID, + }, nil +} diff --git a/example/goroutine-leak/cmd/webserver/internal/service/foo_test.go b/example/goroutine-leak/cmd/webserver/internal/service/foo_test.go new file mode 100644 index 0000000..41ea285 --- /dev/null +++ b/example/goroutine-leak/cmd/webserver/internal/service/foo_test.go @@ -0,0 +1,27 @@ +package service + +import ( + "context" + "testing" +) + +type bar struct { +} + +func (bar) GetBarID(_ context.Context, _ int) (int, error) { + return 1, nil +} + +type baz struct { +} + +func (baz) GetBazID(_ context.Context, _ int) (int, error) { + return 2, nil +} + +func TestFoo(t *testing.T) { + pool := NewPool(100) + foo := NewFooService(pool, bar{}, baz{}) + + _, _ = foo.Foo(context.Background(), 3) +} diff --git a/example/goroutine-leak/cmd/webserver/internal/service/internal/mock/user.go b/example/goroutine-leak/cmd/webserver/internal/service/internal/mock/user.go new file mode 100644 index 0000000..45a2526 --- /dev/null +++ b/example/goroutine-leak/cmd/webserver/internal/service/internal/mock/user.go @@ -0,0 +1,90 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: user.go + +// Package mock_service is a generated GoMock package. +package mock_service + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockUserRepository is a mock of UserRepository interface. +type MockUserRepository struct { + ctrl *gomock.Controller + recorder *MockUserRepositoryMockRecorder +} + +// MockUserRepositoryMockRecorder is the mock recorder for MockUserRepository. +type MockUserRepositoryMockRecorder struct { + mock *MockUserRepository +} + +// NewMockUserRepository creates a new mock instance. +func NewMockUserRepository(ctrl *gomock.Controller) *MockUserRepository { + mock := &MockUserRepository{ctrl: ctrl} + mock.recorder = &MockUserRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUserRepository) EXPECT() *MockUserRepositoryMockRecorder { + return m.recorder +} + +// Create mocks base method. +func (m *MockUserRepository) Create(ctx context.Context, name, email string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Create", ctx, name, email) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockUserRepositoryMockRecorder) Create(ctx, name, email interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockUserRepository)(nil).Create), ctx, name, email) +} + +// MockAnalytics is a mock of Analytics interface. +type MockAnalytics struct { + ctrl *gomock.Controller + recorder *MockAnalyticsMockRecorder +} + +// MockAnalyticsMockRecorder is the mock recorder for MockAnalytics. +type MockAnalyticsMockRecorder struct { + mock *MockAnalytics +} + +// NewMockAnalytics creates a new mock instance. +func NewMockAnalytics(ctrl *gomock.Controller) *MockAnalytics { + mock := &MockAnalytics{ctrl: ctrl} + mock.recorder = &MockAnalyticsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAnalytics) EXPECT() *MockAnalyticsMockRecorder { + return m.recorder +} + +// Send mocks base method. +func (m *MockAnalytics) Send(ctx context.Context, message string, args ...string) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, message} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Send", varargs...) +} + +// Send indicates an expected call of Send. +func (mr *MockAnalyticsMockRecorder) Send(ctx, message interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, message}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockAnalytics)(nil).Send), varargs...) +} diff --git a/example/goroutine-leak/cmd/webserver/internal/service/user.go b/example/goroutine-leak/cmd/webserver/internal/service/user.go index 0f47415..01e4ea4 100644 --- a/example/goroutine-leak/cmd/webserver/internal/service/user.go +++ b/example/goroutine-leak/cmd/webserver/internal/service/user.go @@ -1,6 +1,114 @@ package service -import "context" +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" +) + +//go:generate mockgen -source=$GOFILE -destination=internal/mock/$GOFILE + +type Pool struct { + tasks chan func() + stop atomic.Bool + stopCh chan struct{} + wg sync.WaitGroup + + poolCtx context.Context + cancel context.CancelFunc +} + +func NewPool(count int) *Pool { + poolCtx, cancel := context.WithCancel(context.Background()) + + pool := &Pool{ + tasks: make(chan func(), count), + stop: atomic.Bool{}, + stopCh: make(chan struct{}), + poolCtx: poolCtx, + cancel: cancel, + } + + pool.wg.Add(count) // counter+=100 + + for i := 0; i < count; i++ { + go func() { + defer pool.wg.Done() // counter-- + + for { + select { + case task := <-pool.tasks: + task() // blocking call + case <-pool.stopCh: + return + } + } + }() + } + + return pool +} + +func mergeContext(poolCtx, taskCtx context.Context) context.Context { + taskCtx, cancel := context.WithCancel(taskCtx) + + go func() { + select { + case <-poolCtx.Done(): + cancel() + case <-taskCtx.Done(): + cancel() + } + }() + + return taskCtx +} + +func (p *Pool) Enqueue(ctx context.Context, task func(poolCtx, taskCtx context.Context)) error { + if p.stop.Load() { + return errors.New("Pool is stopping") // ErrPoolIsStopping + } + + taskCtx := context.WithoutCancel(ctx) // <- + + // (2) task(/*give it a time to wrap up*/, taskCtx) // <- + + select { + case p.tasks <- func() { task(p.poolCtx, taskCtx) }: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (p *Pool) Stop(timeout time.Duration) error { + if !p.stop.CompareAndSwap(false, true) { + return nil + } + + close(p.stopCh) // send a signal to stop + + p.cancel() // <-poolCtx.Done() + + doneCh := make(chan struct{}) + + go func() { + p.wg.Wait() // hangs + close(doneCh) // send a signal that we are done + }() + + select { + case <-time.After(timeout): // timeout + return errors.New("stop timeout") // ErrStopTimout + case <-doneCh: + } + + return nil +} + +//------------------ type UserRepository interface { Create(ctx context.Context, name, email string) (string, error) @@ -13,21 +121,24 @@ type Analytics interface { type UserService struct { repo UserRepository analytics Analytics + pool *Pool } func NewUserService(repo UserRepository, analytics Analytics) *UserService { - return &UserService{repo: repo, analytics: analytics} + return &UserService{ + repo: repo, + analytics: analytics, + pool: NewPool(100), + } } func (s *UserService) Create(ctx context.Context, name, email string) error { // create user in the database userID, _ := s.repo.Create(ctx, name, email) - ctx = context.WithoutCancel(ctx) - - // send analytics event synchronously - // which may cause goroutine and memory leak - go s.analytics.Send(ctx, "user created", userID) + _ = s.pool.Enqueue(ctx, func(_, ctx context.Context) { + s.analytics.Send(ctx, "user created", userID) + }) return nil } diff --git a/example/goroutine-leak/cmd/webserver/internal/service/user_test.go b/example/goroutine-leak/cmd/webserver/internal/service/user_test.go new file mode 100644 index 0000000..bc4de37 --- /dev/null +++ b/example/goroutine-leak/cmd/webserver/internal/service/user_test.go @@ -0,0 +1,37 @@ +package service_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/shilkin/inmemory-workerpool-blogpost/example/goroutine-leak/cmd/webserver/internal/service" + mock "github.com/shilkin/inmemory-workerpool-blogpost/example/goroutine-leak/cmd/webserver/internal/service/internal/mock" + "github.com/stretchr/testify/require" +) + +func TestUserCreate(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) // defer ctrl.Finish() <- linters warning + + repoMock := mock.NewMockUserRepository(ctrl) + analyticsMock := mock.NewMockAnalytics(ctrl) + + var isAnalyticsCalled atomic.Bool + + repoMock.EXPECT().Create(context.Background(), "Jon Doe", "jon.doe@example.com").Return("id", nil) + analyticsMock.EXPECT().Send(gomock.Any(), "user created", "id"). + Do(func(_ context.Context, _, _ string) { + isAnalyticsCalled.Store(true) + }) + + userService := service.NewUserService(repoMock, analyticsMock) + + err := userService.Create(context.Background(), "Jon Doe", "jon.doe@example.com") + require.NoError(t, err) // fail now + + // wait for analytics being called or fail the test + require.Eventually(t, isAnalyticsCalled.Load, time.Second, time.Millisecond) +} diff --git a/go.mod b/go.mod index 3619150..53ff7a6 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.4 require ( github.com/daixiang0/gci v0.13.4 + github.com/golang/mock v1.5.0 github.com/golangci/golangci-lint v1.59.1 github.com/google/uuid v1.6.0 github.com/grafana/pyroscope-go v1.1.1 diff --git a/go.sum b/go.sum index 99eaa48..c9258bc 100644 --- a/go.sum +++ b/go.sum @@ -205,6 +205,8 @@ github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= +github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=