Skip to content

Commit

Permalink
fix(action): unexpected behaviour when watching non-existing Actions
Browse files Browse the repository at this point in the history
When passed an action with an unknown ID, the
`ActionClient.WatchProgress()` and `ActionClient.WatchOverallProgress()`
methods experience unexpected behaviour:

- `WatchProgress()` would panic with a nil pointer dereference when
  accessing `a.Status`
- `WatchOverallProgress()` would go into an infinite loop, never
  finishing or closing the channels.
  • Loading branch information
apricote committed Aug 8, 2023
1 parent 913bf74 commit 1c9485e
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 0 deletions.
11 changes: 11 additions & 0 deletions hcloud/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti
errCh <- err
return
}
if len(as) == 0 {
// No actions returned for the provided IDs, they do not exist in the API.
// We need to catch and fail early for this, otherwise the loop will continue
// indefinitely.
errCh <- fmt.Errorf("failed to wait for actions: remaining actions (%v) are not returned from API", opts.ID)
return
}

for _, a := range as {
switch a.Status {
Expand Down Expand Up @@ -282,6 +289,10 @@ func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-cha
errCh <- err
return
}
if a == nil {
errCh <- fmt.Errorf("failed to wait for action %d: action not returned from API", action.ID)
return
}

switch a.Status {
case ActionStatusRunning:
Expand Down
132 changes: 132 additions & 0 deletions hcloud/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"net/http"
"reflect"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -286,6 +287,85 @@ func TestActionClientWatchOverallProgress(t *testing.T) {
}
}

func TestActionClientWatchOverallProgressInvalidID(t *testing.T) {
env := newTestEnv()
defer env.Teardown()

callCount := 0

env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) {
callCount++
var actions []schema.Action

switch callCount {
case 1:
default:
t.Errorf("unexpected number of calls to the test server: %v", callCount)
}

w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(struct {
Actions []schema.Action `json:"actions"`
Meta schema.Meta `json:"meta"`
}{
Actions: actions,
Meta: schema.Meta{
Pagination: &schema.MetaPagination{
Page: 1,
LastPage: 1,
PerPage: len(actions),
TotalEntries: len(actions),
},
},
})
})

actions := []*Action{
{
ID: 1,
Status: ActionStatusRunning,
},
}

ctx := context.Background()
progressCh, errCh := env.Client.Action.WatchOverallProgress(ctx, actions)
progressUpdates := []int{}
errs := []error{}

moreProgress, moreErrors := true, true

for moreProgress || moreErrors {
var progress int
var err error

select {
case progress, moreProgress = <-progressCh:
if moreProgress {
progressUpdates = append(progressUpdates, progress)
}
case err, moreErrors = <-errCh:
if moreErrors {
errs = append(errs, err)
}
}
}

if len(errs) != 1 {
t.Fatalf("expected to receive one error: %v", errs)
}

err := errs[0]

if !strings.HasPrefix(err.Error(), "failed to wait for actions") {
t.Fatalf("expected failed to wait for actions error, but got: %#v", err)
}

expectedProgressUpdates := []int{}
if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) {
t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates)
}
}

func TestActionClientWatchProgress(t *testing.T) {
env := newTestEnv()
defer env.Teardown()
Expand Down Expand Up @@ -385,3 +465,55 @@ func TestActionClientWatchProgressError(t *testing.T) {
t.Fatal("expected an error")
}
}

func TestActionClientWatchProgressInvalidID(t *testing.T) {
env := newTestEnv()
defer env.Teardown()

callCount := 0

env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
switch callCount {
case 1:
_ = json.NewEncoder(w).Encode(schema.ErrorResponse{
Error: schema.Error{
Code: string(ErrorCodeNotFound),
Message: "action with ID '1' not found",
Details: nil,
},
})
default:
t.Errorf("unexpected number of calls to the test server: %v", callCount)
}
})
action := &Action{
ID: 1,
}

ctx := context.Background()
progressCh, errCh := env.Client.Action.WatchProgress(ctx, action)
var (
progressUpdates []int
err error
)

loop:
for {
select {
case progress := <-progressCh:
progressUpdates = append(progressUpdates, progress)
case err = <-errCh:
break loop
}
}

if !strings.HasPrefix(err.Error(), "failed to wait for action") {
t.Fatalf("expected failed to wait for action error, but got: %#v", err)
}
if len(progressUpdates) != 0 {
t.Fatalf("unexpected progress updates: %v", progressUpdates)
}
}

0 comments on commit 1c9485e

Please sign in to comment.