diff --git a/README.md b/README.md index 590bf95..d290a79 100644 --- a/README.md +++ b/README.md @@ -90,10 +90,38 @@ Otherwise, `Execute` returns the result of the request. If a panic occurs in the request, `CircuitBreaker` handles it as an error and causes the same panic again. + +V2 Implementation +--- + +The [v2 implementation](./v2) provides the same CircuitBreaker logic, but with support for generics in Go. + +This change allows for CircuitBreaker instances to specify the handled type directly, and skips type-casting an `any` +or `interface{}` type into the desired target one. + +This change mostly focuses on the [CircuitBreaker's Execute](./v2/gobreaker.go#L228) method, which accepts an +executable function of a given type: + +```go +func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) +``` + + Example ------- +> *v1 Example* + ```go +import ( + "fmt" + "io" + "log" + "net/http" + + "github.com/sony/gobreaker" +) + var cb *breaker.CircuitBreaker func Get(url string) ([]byte, error) { @@ -119,6 +147,45 @@ func Get(url string) ([]byte, error) { } ``` + +> *v2 Example* + + +```go +import ( + "fmt" + "io" + "log" + "net/http" + + "github.com/sony/gobreaker/v2" +) + +var cb *gobreaker.CircuitBreaker[[]byte] + +func Get(url string) ([]byte, error) { + body, err := cb.Execute(func() ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return body, nil + }) + if err != nil { + return nil, err + } + + return body, nil +} +``` + See [example](https://github.com/sony/gobreaker/blob/master/example) for details. License diff --git a/v2/example/http_breaker.go b/v2/example/http_breaker.go new file mode 100644 index 0000000..271ff8f --- /dev/null +++ b/v2/example/http_breaker.go @@ -0,0 +1,55 @@ +package main + +import ( + "fmt" + "io" + "log" + "net/http" + + "github.com/sony/gobreaker/v2" +) + +var cb *gobreaker.CircuitBreaker[[]byte] + +func init() { + var st gobreaker.Settings + st.Name = "HTTP GET" + st.ReadyToTrip = func(counts gobreaker.Counts) bool { + failureRatio := float64(counts.TotalFailures) / float64(counts.Requests) + return counts.Requests >= 3 && failureRatio >= 0.6 + } + + cb = gobreaker.NewCircuitBreaker[[]byte](st) +} + +// Get wraps http.Get in CircuitBreaker. +func Get(url string) ([]byte, error) { + body, err := cb.Execute(func() ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return body, nil + }) + if err != nil { + return nil, err + } + + return body, nil +} + +func main() { + body, err := Get("http://www.google.com/robots.txt") + if err != nil { + log.Fatal(err) + } + + fmt.Println(string(body)) +} diff --git a/v2/go.mod b/v2/go.mod new file mode 100644 index 0000000..9e1537a --- /dev/null +++ b/v2/go.mod @@ -0,0 +1,11 @@ +module github.com/sony/gobreaker/v2 + +go 1.21 + +require github.com/stretchr/testify v1.8.4 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/v2/go.sum b/v2/go.sum new file mode 100644 index 0000000..fa4b6e6 --- /dev/null +++ b/v2/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/v2/gobreaker.go b/v2/gobreaker.go new file mode 100644 index 0000000..2f77a93 --- /dev/null +++ b/v2/gobreaker.go @@ -0,0 +1,382 @@ +// Package gobreaker implements the Circuit Breaker pattern. +// See https://msdn.microsoft.com/en-us/library/dn589784.aspx. +package gobreaker + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// State is a type that represents a state of CircuitBreaker. +type State int + +// These constants are states of CircuitBreaker. +const ( + StateClosed State = iota + StateHalfOpen + StateOpen +) + +var ( + // ErrTooManyRequests is returned when the CB state is half open and the requests count is over the cb maxRequests + ErrTooManyRequests = errors.New("too many requests") + // ErrOpenState is returned when the CB state is open + ErrOpenState = errors.New("circuit breaker is open") +) + +// String implements stringer interface. +func (s State) String() string { + switch s { + case StateClosed: + return "closed" + case StateHalfOpen: + return "half-open" + case StateOpen: + return "open" + default: + return fmt.Sprintf("unknown state: %d", s) + } +} + +// Counts holds the numbers of requests and their successes/failures. +// CircuitBreaker clears the internal Counts either +// on the change of the state or at the closed-state intervals. +// Counts ignores the results of the requests sent before clearing. +type Counts struct { + Requests uint32 + TotalSuccesses uint32 + TotalFailures uint32 + ConsecutiveSuccesses uint32 + ConsecutiveFailures uint32 +} + +func (c *Counts) onRequest() { + c.Requests++ +} + +func (c *Counts) onSuccess() { + c.TotalSuccesses++ + c.ConsecutiveSuccesses++ + c.ConsecutiveFailures = 0 +} + +func (c *Counts) onFailure() { + c.TotalFailures++ + c.ConsecutiveFailures++ + c.ConsecutiveSuccesses = 0 +} + +func (c *Counts) clear() { + c.Requests = 0 + c.TotalSuccesses = 0 + c.TotalFailures = 0 + c.ConsecutiveSuccesses = 0 + c.ConsecutiveFailures = 0 +} + +// Settings configures CircuitBreaker: +// +// Name is the name of the CircuitBreaker. +// +// MaxRequests is the maximum number of requests allowed to pass through +// when the CircuitBreaker is half-open. +// If MaxRequests is 0, the CircuitBreaker allows only 1 request. +// +// Interval is the cyclic period of the closed state +// for the CircuitBreaker to clear the internal Counts. +// If Interval is less than or equal to 0, the CircuitBreaker doesn't clear internal Counts during the closed state. +// +// Timeout is the period of the open state, +// after which the state of the CircuitBreaker becomes half-open. +// If Timeout is less than or equal to 0, the timeout value of the CircuitBreaker is set to 60 seconds. +// +// ReadyToTrip is called with a copy of Counts whenever a request fails in the closed state. +// If ReadyToTrip returns true, the CircuitBreaker will be placed into the open state. +// If ReadyToTrip is nil, default ReadyToTrip is used. +// Default ReadyToTrip returns true when the number of consecutive failures is more than 5. +// +// OnStateChange is called whenever the state of the CircuitBreaker changes. +// +// IsSuccessful is called with the error returned from a request. +// If IsSuccessful returns true, the error is counted as a success. +// Otherwise the error is counted as a failure. +// If IsSuccessful is nil, default IsSuccessful is used, which returns false for all non-nil errors. +type Settings struct { + Name string + MaxRequests uint32 + Interval time.Duration + Timeout time.Duration + ReadyToTrip func(counts Counts) bool + OnStateChange func(name string, from State, to State) + IsSuccessful func(err error) bool +} + +// CircuitBreaker is a state machine to prevent sending requests that are likely to fail. +type CircuitBreaker[T any] struct { + name string + maxRequests uint32 + interval time.Duration + timeout time.Duration + readyToTrip func(counts Counts) bool + isSuccessful func(err error) bool + onStateChange func(name string, from State, to State) + + mutex sync.Mutex + state State + generation uint64 + counts Counts + expiry time.Time +} + +// TwoStepCircuitBreaker is like CircuitBreaker but instead of surrounding a function +// with the breaker functionality, it only checks whether a request can proceed and +// expects the caller to report the outcome in a separate step using a callback. +type TwoStepCircuitBreaker[T any] struct { + cb *CircuitBreaker[T] +} + +// NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. +func NewCircuitBreaker[T any](st Settings) *CircuitBreaker[T] { + cb := new(CircuitBreaker[T]) + + cb.name = st.Name + cb.onStateChange = st.OnStateChange + + if st.MaxRequests == 0 { + cb.maxRequests = 1 + } else { + cb.maxRequests = st.MaxRequests + } + + if st.Interval <= 0 { + cb.interval = defaultInterval + } else { + cb.interval = st.Interval + } + + if st.Timeout <= 0 { + cb.timeout = defaultTimeout + } else { + cb.timeout = st.Timeout + } + + if st.ReadyToTrip == nil { + cb.readyToTrip = defaultReadyToTrip + } else { + cb.readyToTrip = st.ReadyToTrip + } + + if st.IsSuccessful == nil { + cb.isSuccessful = defaultIsSuccessful + } else { + cb.isSuccessful = st.IsSuccessful + } + + cb.toNewGeneration(time.Now()) + + return cb +} + +// NewTwoStepCircuitBreaker returns a new TwoStepCircuitBreaker configured with the given Settings. +func NewTwoStepCircuitBreaker[T any](st Settings) *TwoStepCircuitBreaker[T] { + return &TwoStepCircuitBreaker[T]{ + cb: NewCircuitBreaker[T](st), + } +} + +const defaultInterval = time.Duration(0) * time.Second +const defaultTimeout = time.Duration(60) * time.Second + +func defaultReadyToTrip(counts Counts) bool { + return counts.ConsecutiveFailures > 5 +} + +func defaultIsSuccessful(err error) bool { + return err == nil +} + +// Name returns the name of the CircuitBreaker. +func (cb *CircuitBreaker[T]) Name() string { + return cb.name +} + +// State returns the current state of the CircuitBreaker. +func (cb *CircuitBreaker[T]) State() State { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, _ := cb.currentState(now) + return state +} + +// Counts returns internal counters +func (cb *CircuitBreaker[T]) Counts() Counts { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + return cb.counts +} + +// Execute runs the given request if the CircuitBreaker accepts it. +// Execute returns an error instantly if the CircuitBreaker rejects the request. +// Otherwise, Execute returns the result of the request. +// If a panic occurs in the request, the CircuitBreaker handles it as an error +// and causes the same panic again. +func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) { + var zero T + + generation, err := cb.beforeRequest() + if err != nil { + return zero, err + } + + defer func() { + e := recover() + if e != nil { + cb.afterRequest(generation, false) + panic(e) + } + }() + + result, err := req() + cb.afterRequest(generation, cb.isSuccessful(err)) + return result, err +} + +// Name returns the name of the TwoStepCircuitBreaker. +func (tscb *TwoStepCircuitBreaker[T]) Name() string { + return tscb.cb.Name() +} + +// State returns the current state of the TwoStepCircuitBreaker. +func (tscb *TwoStepCircuitBreaker[T]) State() State { + return tscb.cb.State() +} + +// Counts returns internal counters +func (tscb *TwoStepCircuitBreaker[T]) Counts() Counts { + return tscb.cb.Counts() +} + +// Allow checks if a new request can proceed. It returns a callback that should be used to +// register the success or failure in a separate step. If the circuit breaker doesn't allow +// requests, it returns an error. +func (tscb *TwoStepCircuitBreaker[T]) Allow() (done func(success bool), err error) { + generation, err := tscb.cb.beforeRequest() + if err != nil { + return nil, err + } + + return func(success bool) { + tscb.cb.afterRequest(generation, success) + }, nil +} + +func (cb *CircuitBreaker[T]) beforeRequest() (uint64, error) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, generation := cb.currentState(now) + + if state == StateOpen { + return generation, ErrOpenState + } else if state == StateHalfOpen && cb.counts.Requests >= cb.maxRequests { + return generation, ErrTooManyRequests + } + + cb.counts.onRequest() + return generation, nil +} + +func (cb *CircuitBreaker[T]) afterRequest(before uint64, success bool) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, generation := cb.currentState(now) + if generation != before { + return + } + + if success { + cb.onSuccess(state, now) + } else { + cb.onFailure(state, now) + } +} + +func (cb *CircuitBreaker[T]) onSuccess(state State, now time.Time) { + switch state { + case StateClosed: + cb.counts.onSuccess() + case StateHalfOpen: + cb.counts.onSuccess() + if cb.counts.ConsecutiveSuccesses >= cb.maxRequests { + cb.setState(StateClosed, now) + } + } +} + +func (cb *CircuitBreaker[T]) onFailure(state State, now time.Time) { + switch state { + case StateClosed: + cb.counts.onFailure() + if cb.readyToTrip(cb.counts) { + cb.setState(StateOpen, now) + } + case StateHalfOpen: + cb.setState(StateOpen, now) + } +} + +func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64) { + switch cb.state { + case StateClosed: + if !cb.expiry.IsZero() && cb.expiry.Before(now) { + cb.toNewGeneration(now) + } + case StateOpen: + if cb.expiry.Before(now) { + cb.setState(StateHalfOpen, now) + } + } + return cb.state, cb.generation +} + +func (cb *CircuitBreaker[T]) setState(state State, now time.Time) { + if cb.state == state { + return + } + + prev := cb.state + cb.state = state + + cb.toNewGeneration(now) + + if cb.onStateChange != nil { + cb.onStateChange(cb.name, prev, state) + } +} + +func (cb *CircuitBreaker[T]) toNewGeneration(now time.Time) { + cb.generation++ + cb.counts.clear() + + var zero time.Time + switch cb.state { + case StateClosed: + if cb.interval == 0 { + cb.expiry = zero + } else { + cb.expiry = now.Add(cb.interval) + } + case StateOpen: + cb.expiry = now.Add(cb.timeout) + default: // StateHalfOpen + cb.expiry = zero + } +} diff --git a/v2/gobreaker_test.go b/v2/gobreaker_test.go new file mode 100644 index 0000000..fe77bcf --- /dev/null +++ b/v2/gobreaker_test.go @@ -0,0 +1,392 @@ +package gobreaker + +import ( + "fmt" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var defaultCB *CircuitBreaker[bool] +var customCB *CircuitBreaker[bool] + +type StateChange struct { + name string + from State + to State +} + +var stateChange StateChange + +func pseudoSleep(cb *CircuitBreaker[bool], period time.Duration) { + if !cb.expiry.IsZero() { + cb.expiry = cb.expiry.Add(-period) + } +} + +func succeed(cb *CircuitBreaker[bool]) error { + _, err := cb.Execute(func() (bool, error) { return true, nil }) + return err +} + +func succeedLater(cb *CircuitBreaker[bool], delay time.Duration) <-chan error { + ch := make(chan error) + go func() { + _, err := cb.Execute(func() (bool, error) { + time.Sleep(delay) + return true, nil + }) + ch <- err + }() + return ch +} + +func succeed2Step(cb *TwoStepCircuitBreaker[bool]) error { + done, err := cb.Allow() + if err != nil { + return err + } + + done(true) + return nil +} + +func fail(cb *CircuitBreaker[bool]) error { + msg := "fail" + _, err := cb.Execute(func() (bool, error) { return false, fmt.Errorf(msg) }) + if err.Error() == msg { + return nil + } + return err +} + +func fail2Step(cb *TwoStepCircuitBreaker[bool]) error { + done, err := cb.Allow() + if err != nil { + return err + } + + done(false) + return nil +} + +func causePanic(cb *CircuitBreaker[bool]) error { + _, err := cb.Execute(func() (bool, error) { panic("oops"); return false, nil }) + return err +} + +func newCustom() *CircuitBreaker[bool] { + var customSt Settings + customSt.Name = "cb" + customSt.MaxRequests = 3 + customSt.Interval = time.Duration(30) * time.Second + customSt.Timeout = time.Duration(90) * time.Second + customSt.ReadyToTrip = func(counts Counts) bool { + numReqs := counts.Requests + failureRatio := float64(counts.TotalFailures) / float64(numReqs) + + counts.clear() // no effect on customCB.counts + + return numReqs >= 3 && failureRatio >= 0.6 + } + customSt.OnStateChange = func(name string, from State, to State) { + stateChange = StateChange{name, from, to} + } + + return NewCircuitBreaker[bool](customSt) +} + +func newNegativeDurationCB() *CircuitBreaker[bool] { + var negativeSt Settings + negativeSt.Name = "ncb" + negativeSt.Interval = time.Duration(-30) * time.Second + negativeSt.Timeout = time.Duration(-90) * time.Second + + return NewCircuitBreaker[bool](negativeSt) +} + +func init() { + defaultCB = NewCircuitBreaker[bool](Settings{}) + customCB = newCustom() +} + +func TestStateConstants(t *testing.T) { + assert.Equal(t, State(0), StateClosed) + assert.Equal(t, State(1), StateHalfOpen) + assert.Equal(t, State(2), StateOpen) + + assert.Equal(t, StateClosed.String(), "closed") + assert.Equal(t, StateHalfOpen.String(), "half-open") + assert.Equal(t, StateOpen.String(), "open") + assert.Equal(t, State(100).String(), "unknown state: 100") +} + +func TestNewCircuitBreaker(t *testing.T) { + defaultCB := NewCircuitBreaker[bool](Settings{}) + assert.Equal(t, "", defaultCB.name) + assert.Equal(t, uint32(1), defaultCB.maxRequests) + assert.Equal(t, time.Duration(0), defaultCB.interval) + assert.Equal(t, time.Duration(60)*time.Second, defaultCB.timeout) + assert.NotNil(t, defaultCB.readyToTrip) + assert.Nil(t, defaultCB.onStateChange) + assert.Equal(t, StateClosed, defaultCB.state) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + assert.True(t, defaultCB.expiry.IsZero()) + + customCB := newCustom() + assert.Equal(t, "cb", customCB.name) + assert.Equal(t, uint32(3), customCB.maxRequests) + assert.Equal(t, time.Duration(30)*time.Second, customCB.interval) + assert.Equal(t, time.Duration(90)*time.Second, customCB.timeout) + assert.NotNil(t, customCB.readyToTrip) + assert.NotNil(t, customCB.onStateChange) + assert.Equal(t, StateClosed, customCB.state) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) + assert.False(t, customCB.expiry.IsZero()) + + negativeDurationCB := newNegativeDurationCB() + assert.Equal(t, "ncb", negativeDurationCB.name) + assert.Equal(t, uint32(1), negativeDurationCB.maxRequests) + assert.Equal(t, time.Duration(0)*time.Second, negativeDurationCB.interval) + assert.Equal(t, time.Duration(60)*time.Second, negativeDurationCB.timeout) + assert.NotNil(t, negativeDurationCB.readyToTrip) + assert.Nil(t, negativeDurationCB.onStateChange) + assert.Equal(t, StateClosed, negativeDurationCB.state) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, negativeDurationCB.counts) + assert.True(t, negativeDurationCB.expiry.IsZero()) +} + +func TestDefaultCircuitBreaker(t *testing.T) { + assert.Equal(t, "", defaultCB.Name()) + + for i := 0; i < 5; i++ { + assert.Nil(t, fail(defaultCB)) + } + assert.Equal(t, StateClosed, defaultCB.State()) + assert.Equal(t, Counts{5, 0, 5, 0, 5}, defaultCB.counts) + + assert.Nil(t, succeed(defaultCB)) + assert.Equal(t, StateClosed, defaultCB.State()) + assert.Equal(t, Counts{6, 1, 5, 1, 0}, defaultCB.counts) + + assert.Nil(t, fail(defaultCB)) + assert.Equal(t, StateClosed, defaultCB.State()) + assert.Equal(t, Counts{7, 1, 6, 0, 1}, defaultCB.counts) + + // StateClosed to StateOpen + for i := 0; i < 5; i++ { + assert.Nil(t, fail(defaultCB)) // 6 consecutive failures + } + assert.Equal(t, StateOpen, defaultCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + assert.False(t, defaultCB.expiry.IsZero()) + + assert.Error(t, succeed(defaultCB)) + assert.Error(t, fail(defaultCB)) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + + pseudoSleep(defaultCB, time.Duration(59)*time.Second) + assert.Equal(t, StateOpen, defaultCB.State()) + + // StateOpen to StateHalfOpen + pseudoSleep(defaultCB, time.Duration(1)*time.Second) // over Timeout + assert.Equal(t, StateHalfOpen, defaultCB.State()) + assert.True(t, defaultCB.expiry.IsZero()) + + // StateHalfOpen to StateOpen + assert.Nil(t, fail(defaultCB)) + assert.Equal(t, StateOpen, defaultCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + assert.False(t, defaultCB.expiry.IsZero()) + + // StateOpen to StateHalfOpen + pseudoSleep(defaultCB, time.Duration(60)*time.Second) + assert.Equal(t, StateHalfOpen, defaultCB.State()) + assert.True(t, defaultCB.expiry.IsZero()) + + // StateHalfOpen to StateClosed + assert.Nil(t, succeed(defaultCB)) + assert.Equal(t, StateClosed, defaultCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + assert.True(t, defaultCB.expiry.IsZero()) +} + +func TestCustomCircuitBreaker(t *testing.T) { + assert.Equal(t, "cb", customCB.Name()) + + for i := 0; i < 5; i++ { + assert.Nil(t, succeed(customCB)) + assert.Nil(t, fail(customCB)) + } + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{10, 5, 5, 0, 1}, customCB.counts) + + pseudoSleep(customCB, time.Duration(29)*time.Second) + assert.Nil(t, succeed(customCB)) + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{11, 6, 5, 1, 0}, customCB.counts) + + pseudoSleep(customCB, time.Duration(1)*time.Second) // over Interval + assert.Nil(t, fail(customCB)) + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{1, 0, 1, 0, 1}, customCB.counts) + + // StateClosed to StateOpen + assert.Nil(t, succeed(customCB)) + assert.Nil(t, fail(customCB)) // failure ratio: 2/3 >= 0.6 + assert.Equal(t, StateOpen, customCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) + assert.False(t, customCB.expiry.IsZero()) + assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) + + // StateOpen to StateHalfOpen + pseudoSleep(customCB, time.Duration(90)*time.Second) + assert.Equal(t, StateHalfOpen, customCB.State()) + assert.True(t, defaultCB.expiry.IsZero()) + assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) + + assert.Nil(t, succeed(customCB)) + assert.Nil(t, succeed(customCB)) + assert.Equal(t, StateHalfOpen, customCB.State()) + assert.Equal(t, Counts{2, 2, 0, 2, 0}, customCB.counts) + + // StateHalfOpen to StateClosed + ch := succeedLater(customCB, time.Duration(100)*time.Millisecond) // 3 consecutive successes + time.Sleep(time.Duration(50) * time.Millisecond) + assert.Equal(t, Counts{3, 2, 0, 2, 0}, customCB.counts) + assert.Error(t, succeed(customCB)) // over MaxRequests + assert.Nil(t, <-ch) + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) + assert.False(t, customCB.expiry.IsZero()) + assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) +} + +func TestTwoStepCircuitBreaker(t *testing.T) { + tscb := NewTwoStepCircuitBreaker[bool](Settings{Name: "tscb"}) + assert.Equal(t, "tscb", tscb.Name()) + + for i := 0; i < 5; i++ { + assert.Nil(t, fail2Step(tscb)) + } + + assert.Equal(t, StateClosed, tscb.State()) + assert.Equal(t, Counts{5, 0, 5, 0, 5}, tscb.cb.counts) + + assert.Nil(t, succeed2Step(tscb)) + assert.Equal(t, StateClosed, tscb.State()) + assert.Equal(t, Counts{6, 1, 5, 1, 0}, tscb.cb.counts) + + assert.Nil(t, fail2Step(tscb)) + assert.Equal(t, StateClosed, tscb.State()) + assert.Equal(t, Counts{7, 1, 6, 0, 1}, tscb.cb.counts) + + // StateClosed to StateOpen + for i := 0; i < 5; i++ { + assert.Nil(t, fail2Step(tscb)) // 6 consecutive failures + } + assert.Equal(t, StateOpen, tscb.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) + assert.False(t, tscb.cb.expiry.IsZero()) + + assert.Error(t, succeed2Step(tscb)) + assert.Error(t, fail2Step(tscb)) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) + + pseudoSleep(tscb.cb, time.Duration(59)*time.Second) + assert.Equal(t, StateOpen, tscb.State()) + + // StateOpen to StateHalfOpen + pseudoSleep(tscb.cb, time.Duration(1)*time.Second) // over Timeout + assert.Equal(t, StateHalfOpen, tscb.State()) + assert.True(t, tscb.cb.expiry.IsZero()) + + // StateHalfOpen to StateOpen + assert.Nil(t, fail2Step(tscb)) + assert.Equal(t, StateOpen, tscb.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) + assert.False(t, tscb.cb.expiry.IsZero()) + + // StateOpen to StateHalfOpen + pseudoSleep(tscb.cb, time.Duration(60)*time.Second) + assert.Equal(t, StateHalfOpen, tscb.State()) + assert.True(t, tscb.cb.expiry.IsZero()) + + // StateHalfOpen to StateClosed + assert.Nil(t, succeed2Step(tscb)) + assert.Equal(t, StateClosed, tscb.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) + assert.True(t, tscb.cb.expiry.IsZero()) +} + +func TestPanicInRequest(t *testing.T) { + assert.Panics(t, func() { causePanic(defaultCB) }) + assert.Equal(t, Counts{1, 0, 1, 0, 1}, defaultCB.counts) +} + +func TestGeneration(t *testing.T) { + pseudoSleep(customCB, time.Duration(29)*time.Second) + assert.Nil(t, succeed(customCB)) + ch := succeedLater(customCB, time.Duration(1500)*time.Millisecond) + time.Sleep(time.Duration(500) * time.Millisecond) + assert.Equal(t, Counts{2, 1, 0, 1, 0}, customCB.counts) + + time.Sleep(time.Duration(500) * time.Millisecond) // over Interval + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) + + // the request from the previous generation has no effect on customCB.counts + assert.Nil(t, <-ch) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) +} + +func TestCustomIsSuccessful(t *testing.T) { + isSuccessful := func(error) bool { + return true + } + cb := NewCircuitBreaker[bool](Settings{IsSuccessful: isSuccessful}) + + for i := 0; i < 5; i++ { + assert.Nil(t, fail(cb)) + } + assert.Equal(t, StateClosed, cb.State()) + assert.Equal(t, Counts{5, 5, 0, 5, 0}, cb.counts) + + cb.counts.clear() + + cb.isSuccessful = func(err error) bool { + return err == nil + } + for i := 0; i < 6; i++ { + assert.Nil(t, fail(cb)) + } + assert.Equal(t, StateOpen, cb.State()) + +} + +func TestCircuitBreakerInParallel(t *testing.T) { + runtime.GOMAXPROCS(runtime.NumCPU()) + + ch := make(chan error) + + const numReqs = 10000 + routine := func() { + for i := 0; i < numReqs; i++ { + ch <- succeed(customCB) + } + } + + const numRoutines = 10 + for i := 0; i < numRoutines; i++ { + go routine() + } + + total := uint32(numReqs * numRoutines) + for i := uint32(0); i < total; i++ { + err := <-ch + assert.Nil(t, err) + } + assert.Equal(t, Counts{total, total, 0, total, 0}, customCB.counts) +}