diff --git a/v2/gobreaker.go b/v2/gobreaker.go index c0382d1..b4d925e 100644 --- a/v2/gobreaker.go +++ b/v2/gobreaker.go @@ -113,15 +113,48 @@ type Settings struct { IsSuccessful func(err error) bool } -// CircuitBreaker is a state machine to prevent sending requests that are likely to fail. -type CircuitBreaker[T any] struct { - name string +// TrackingSettings configures Tracking: +// +// MaxRequests is the maximum number of requests allowed to pass through +// when the CircuitBreaker is half-open. +// If MaxRequests is 0, the CircuitBreaker allows only 1 request. +// +// Interval is the cyclic period of the closed state +// for the CircuitBreaker to clear the internal Counts. +// If Interval is less than or equal to 0, the CircuitBreaker doesn't clear internal Counts during the closed state. +// +// Timeout is the period of the open state, +// after which the state of the CircuitBreaker becomes half-open. +// If Timeout is less than or equal to 0, the timeout value of the CircuitBreaker is set to 60 seconds. +// +// ReadyToTrip is called with a copy of Counts whenever a request fails in the closed state. +// If ReadyToTrip returns true, the CircuitBreaker will be placed into the open state. +// If ReadyToTrip is nil, default ReadyToTrip is used. +// Default ReadyToTrip returns true when the number of consecutive failures is more than 5. +// +// OnStateChange is called whenever the state of the CircuitBreaker changes. +// +// IsSuccessful is called with the error returned from a request. +// If IsSuccessful returns true, the error is counted as a success. +// Otherwise the error is counted as a failure. +// If IsSuccessful is nil, default IsSuccessful is used, which returns false for all non-nil errors. +type TrackingSettings struct { + MaxRequests uint32 + Interval time.Duration + Timeout time.Duration + ReadyToTrip func(counts Counts) bool + OnStateChange func(from State, to State) + IsSuccessful func(err error) bool +} + +// Tracking implements a state machine to implement CircuitBreakers. +type Tracking struct { maxRequests uint32 interval time.Duration timeout time.Duration readyToTrip func(counts Counts) bool - isSuccessful func(err error) bool - onStateChange func(name string, from State, to State) + IsSuccessful func(err error) bool + onStateChange func(from State, to State) mutex sync.Mutex state State @@ -130,6 +163,12 @@ type CircuitBreaker[T any] struct { expiry time.Time } +// CircuitBreaker is a state machine to prevent sending requests that are likely to fail. +type CircuitBreaker[T any] struct { + name string + tracking *Tracking +} + // TwoStepCircuitBreaker is like CircuitBreaker but instead of surrounding a function // with the breaker functionality, it only checks whether a request can proceed and // expects the caller to report the outcome in a separate step using a callback. @@ -142,6 +181,41 @@ func NewCircuitBreaker[T any](st Settings) *CircuitBreaker[T] { cb := new(CircuitBreaker[T]) cb.name = st.Name + cb.tracking = NewTracking(TrackingSettings{ + MaxRequests: st.MaxRequests, + Interval: st.Interval, + Timeout: st.Timeout, + ReadyToTrip: st.ReadyToTrip, + OnStateChange: func(from, to State) { + if st.OnStateChange != nil { + st.OnStateChange(cb.name, from, to) + } + }, + }) + + return cb +} + +// NewTwoStepCircuitBreaker returns a new TwoStepCircuitBreaker configured with the given Settings. +func NewTwoStepCircuitBreaker[T any](st Settings) *TwoStepCircuitBreaker[T] { + return &TwoStepCircuitBreaker[T]{ + cb: NewCircuitBreaker[T](st), + } +} + +const defaultInterval = time.Duration(0) * time.Second +const defaultTimeout = time.Duration(60) * time.Second + +func defaultReadyToTrip(counts Counts) bool { + return counts.ConsecutiveFailures > 5 +} + +func defaultIsSuccessful(err error) bool { + return err == nil +} + +func NewTracking(st TrackingSettings) *Tracking { + cb := Tracking{} cb.onStateChange = st.OnStateChange if st.MaxRequests == 0 { @@ -169,112 +243,16 @@ func NewCircuitBreaker[T any](st Settings) *CircuitBreaker[T] { } if st.IsSuccessful == nil { - cb.isSuccessful = defaultIsSuccessful + cb.IsSuccessful = defaultIsSuccessful } else { - cb.isSuccessful = st.IsSuccessful + cb.IsSuccessful = st.IsSuccessful } cb.toNewGeneration(time.Now()) - - return cb -} - -// NewTwoStepCircuitBreaker returns a new TwoStepCircuitBreaker configured with the given Settings. -func NewTwoStepCircuitBreaker[T any](st Settings) *TwoStepCircuitBreaker[T] { - return &TwoStepCircuitBreaker[T]{ - cb: NewCircuitBreaker[T](st), - } -} - -const defaultInterval = time.Duration(0) * time.Second -const defaultTimeout = time.Duration(60) * time.Second - -func defaultReadyToTrip(counts Counts) bool { - return counts.ConsecutiveFailures > 5 -} - -func defaultIsSuccessful(err error) bool { - return err == nil -} - -// Name returns the name of the CircuitBreaker. -func (cb *CircuitBreaker[T]) Name() string { - return cb.name -} - -// State returns the current state of the CircuitBreaker. -func (cb *CircuitBreaker[T]) State() State { - cb.mutex.Lock() - defer cb.mutex.Unlock() - - now := time.Now() - state, _ := cb.currentState(now) - return state -} - -// Counts returns internal counters -func (cb *CircuitBreaker[T]) Counts() Counts { - cb.mutex.Lock() - defer cb.mutex.Unlock() - - return cb.counts -} - -// Execute runs the given request if the CircuitBreaker accepts it. -// Execute returns an error instantly if the CircuitBreaker rejects the request. -// Otherwise, Execute returns the result of the request. -// If a panic occurs in the request, the CircuitBreaker handles it as an error -// and causes the same panic again. -func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) { - generation, err := cb.beforeRequest() - if err != nil { - var defaultValue T - return defaultValue, err - } - - defer func() { - e := recover() - if e != nil { - cb.afterRequest(generation, false) - panic(e) - } - }() - - result, err := req() - cb.afterRequest(generation, cb.isSuccessful(err)) - return result, err -} - -// Name returns the name of the TwoStepCircuitBreaker. -func (tscb *TwoStepCircuitBreaker[T]) Name() string { - return tscb.cb.Name() -} - -// State returns the current state of the TwoStepCircuitBreaker. -func (tscb *TwoStepCircuitBreaker[T]) State() State { - return tscb.cb.State() -} - -// Counts returns internal counters -func (tscb *TwoStepCircuitBreaker[T]) Counts() Counts { - return tscb.cb.Counts() -} - -// Allow checks if a new request can proceed. It returns a callback that should be used to -// register the success or failure in a separate step. If the circuit breaker doesn't allow -// requests, it returns an error. -func (tscb *TwoStepCircuitBreaker[T]) Allow() (done func(success bool), err error) { - generation, err := tscb.cb.beforeRequest() - if err != nil { - return nil, err - } - - return func(success bool) { - tscb.cb.afterRequest(generation, success) - }, nil + return &cb } -func (cb *CircuitBreaker[T]) beforeRequest() (uint64, error) { +func (cb *Tracking) BeforeRequest() (uint64, error) { cb.mutex.Lock() defer cb.mutex.Unlock() @@ -291,7 +269,7 @@ func (cb *CircuitBreaker[T]) beforeRequest() (uint64, error) { return generation, nil } -func (cb *CircuitBreaker[T]) afterRequest(before uint64, success bool) { +func (cb *Tracking) AfterRequest(before uint64, success bool) { cb.mutex.Lock() defer cb.mutex.Unlock() @@ -308,7 +286,7 @@ func (cb *CircuitBreaker[T]) afterRequest(before uint64, success bool) { } } -func (cb *CircuitBreaker[T]) onSuccess(state State, now time.Time) { +func (cb *Tracking) onSuccess(state State, now time.Time) { switch state { case StateClosed: cb.counts.onSuccess() @@ -320,7 +298,7 @@ func (cb *CircuitBreaker[T]) onSuccess(state State, now time.Time) { } } -func (cb *CircuitBreaker[T]) onFailure(state State, now time.Time) { +func (cb *Tracking) onFailure(state State, now time.Time) { switch state { case StateClosed: cb.counts.onFailure() @@ -332,7 +310,7 @@ func (cb *CircuitBreaker[T]) onFailure(state State, now time.Time) { } } -func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64) { +func (cb *Tracking) currentState(now time.Time) (State, uint64) { switch cb.state { case StateClosed: if !cb.expiry.IsZero() && cb.expiry.Before(now) { @@ -346,7 +324,7 @@ func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64) { return cb.state, cb.generation } -func (cb *CircuitBreaker[T]) setState(state State, now time.Time) { +func (cb *Tracking) setState(state State, now time.Time) { if cb.state == state { return } @@ -357,11 +335,11 @@ func (cb *CircuitBreaker[T]) setState(state State, now time.Time) { cb.toNewGeneration(now) if cb.onStateChange != nil { - cb.onStateChange(cb.name, prev, state) + cb.onStateChange(prev, state) } } -func (cb *CircuitBreaker[T]) toNewGeneration(now time.Time) { +func (cb *Tracking) toNewGeneration(now time.Time) { cb.generation++ cb.counts.clear() @@ -379,3 +357,90 @@ func (cb *CircuitBreaker[T]) toNewGeneration(now time.Time) { cb.expiry = zero } } + +// State returns the current state of the CircuitBreaker. +func (cb *Tracking) State() State { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, _ := cb.currentState(now) + return state +} + +// Counts returns internal counters +func (cb *Tracking) Counts() Counts { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + return cb.counts +} + +// Name returns the name of the CircuitBreaker. +func (cb *CircuitBreaker[T]) Name() string { + return cb.name +} + +// State returns the current state of the CircuitBreaker. +func (cb *CircuitBreaker[T]) State() State { + return cb.tracking.State() +} + +// Counts returns internal counters +func (cb *CircuitBreaker[T]) Counts() Counts { + return cb.tracking.Counts() +} + +// Execute runs the given request if the CircuitBreaker accepts it. +// Execute returns an error instantly if the CircuitBreaker rejects the request. +// Otherwise, Execute returns the result of the request. +// If a panic occurs in the request, the CircuitBreaker handles it as an error +// and causes the same panic again. +func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) { + generation, err := cb.tracking.BeforeRequest() + if err != nil { + var defaultValue T + return defaultValue, err + } + + defer func() { + e := recover() + if e != nil { + cb.tracking.AfterRequest(generation, false) + panic(e) + } + }() + + result, err := req() + cb.tracking.AfterRequest(generation, cb.tracking.IsSuccessful(err)) + return result, err +} + +// Name returns the name of the TwoStepCircuitBreaker. +func (tscb *TwoStepCircuitBreaker[T]) Name() string { + return tscb.cb.Name() +} + +// State returns the current state of the TwoStepCircuitBreaker. +func (tscb *TwoStepCircuitBreaker[T]) State() State { + return tscb.cb.State() +} + +// Counts returns internal counters +func (tscb *TwoStepCircuitBreaker[T]) Counts() Counts { + return tscb.cb.Counts() +} + +// Allow checks if a new request can proceed. It returns a callback that should be used to +// register the success or failure in a separate step. If the circuit breaker doesn't allow +// requests, it returns an error. +func (tscb *TwoStepCircuitBreaker[T]) Allow() (done func(success bool), err error) { + generation, err := tscb.cb.tracking.BeforeRequest() + if err != nil { + return nil, err + } + + return func(success bool) { + tscb.cb.tracking.AfterRequest(generation, success) + }, nil +}