Skip to content

Commit

Permalink
Rename
Browse files Browse the repository at this point in the history
  • Loading branch information
YoshiyukiMineo committed Nov 16, 2024
1 parent d6880cf commit c1a3eca
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
22 changes: 11 additions & 11 deletions v2/distributed_gobreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@ type SharedStateStore interface {
// DistributedCircuitBreaker extends CircuitBreaker with distributed state storage
type DistributedCircuitBreaker[T any] struct {
*CircuitBreaker[T]
cacheClient SharedStateStore
store SharedStateStore
}

// NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker configured with the given StorageSettings
func NewDistributedCircuitBreaker[T any](storageClient SharedStateStore, settings Settings) *DistributedCircuitBreaker[T] {
cb := NewCircuitBreaker[T](settings)
return &DistributedCircuitBreaker[T]{
CircuitBreaker: cb,
cacheClient: storageClient,
store: storageClient,
}
}

func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State {
if rcb.cacheClient == nil {
if rcb.store == nil {
return rcb.CircuitBreaker.State()
}

state, err := rcb.cacheClient.GetState(ctx)
state, err := rcb.store.GetState(ctx)
if err != nil {
// Fallback to in-memory state if Storage fails
return rcb.CircuitBreaker.State()
Expand All @@ -51,7 +51,7 @@ func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State {
// Update the state in Storage if it has changed
if currentState != state.State {
state.State = currentState
if err := rcb.cacheClient.SetState(ctx, state); err != nil {
if err := rcb.store.SetState(ctx, state); err != nil {
// Log the error, but continue with the current state
fmt.Printf("Failed to update state in storage: %v\n", err)
}
Expand All @@ -62,7 +62,7 @@ func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State {

// Execute runs the given request if the DistributedCircuitBreaker accepts it
func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (T, error) {
if rcb.cacheClient == nil {
if rcb.store == nil {
return rcb.CircuitBreaker.Execute(req)
}
generation, err := rcb.beforeRequest(ctx)
Expand All @@ -86,7 +86,7 @@ func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func()
}

func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) {
state, err := rcb.cacheClient.GetState(ctx)
state, err := rcb.store.GetState(ctx)
if err != nil {
return 0, err
}
Expand All @@ -95,7 +95,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin

if currentState != state.State {
rcb.setState(&state, currentState, now)
err = rcb.cacheClient.SetState(ctx, state)
err = rcb.store.SetState(ctx, state)
if err != nil {
return 0, err
}
Expand All @@ -108,7 +108,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin
}

state.Counts.onRequest()
err = rcb.cacheClient.SetState(ctx, state)
err = rcb.store.SetState(ctx, state)
if err != nil {
return 0, err
}
Expand All @@ -117,7 +117,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin
}

func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) {
state, err := rcb.cacheClient.GetState(ctx)
state, err := rcb.store.GetState(ctx)
if err != nil {
return
}
Expand All @@ -133,7 +133,7 @@ func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, befor
rcb.onFailure(&state, currentState, now)
}

rcb.cacheClient.SetState(ctx, state)
rcb.store.SetState(ctx, state)
}

func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) {
Expand Down
16 changes: 8 additions & 8 deletions v2/distributed_gobreaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Minir
}

func pseudoSleepStorage(ctx context.Context, rcb *DistributedCircuitBreaker[any], period time.Duration) {
state, _ := rcb.cacheClient.GetState(ctx)
state, _ := rcb.store.GetState(ctx)

state.Expiry = state.Expiry.Add(-period)
// Reset counts if the interval has passed
if time.Now().After(state.Expiry) {
state.Counts = Counts{}
}
rcb.cacheClient.SetState(ctx, state)
rcb.store.SetState(ctx, state)
}

func successRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error {
Expand Down Expand Up @@ -174,11 +174,11 @@ func TestDistributedCircuitBreakerCounts(t *testing.T) {
assert.Nil(t, successRequest(ctx, rcb))
}

state, _ := rcb.cacheClient.GetState(ctx)
state, _ := rcb.store.GetState(ctx)
assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts)

assert.Nil(t, failRequest(ctx, rcb))
state, _ = rcb.cacheClient.GetState(ctx)
state, _ = rcb.store.GetState(ctx)
assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts)
}

Expand All @@ -191,7 +191,7 @@ func TestDistributedCircuitBreakerFallback(t *testing.T) {
// Test when Storage is unavailable
mr.Close() // Simulate Storage being unavailable

rcb.cacheClient = nil
rcb.store = nil

state := rcb.State(ctx)
assert.Equal(t, StateClosed, state, "Should fallback to in-memory state when Storage is unavailable")
Expand Down Expand Up @@ -240,14 +240,14 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) {
assert.NoError(t, failRequest(ctx, customRCB))
}

state, err := customRCB.cacheClient.GetState(ctx)
state, err := customRCB.store.GetState(ctx)
assert.NoError(t, err)
assert.Equal(t, StateClosed, state.State)
assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts)

// Perform one more successful request
assert.NoError(t, successRequest(ctx, customRCB))
state, err = customRCB.cacheClient.GetState(ctx)
state, err = customRCB.store.GetState(ctx)
assert.NoError(t, err)
assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts)

Expand All @@ -262,7 +262,7 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) {
// Check if the circuit breaker is now open
assert.Equal(t, StateOpen, customRCB.State(ctx))

state, err = customRCB.cacheClient.GetState(ctx)
state, err = customRCB.store.GetState(ctx)
assert.NoError(t, err)
assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts)
})
Expand Down

0 comments on commit c1a3eca

Please sign in to comment.