From 1c9485e4c0e5d0f7a682d8a2cbb12c93ad87c80e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Julian=20T=C3=B6lle?= Date: Tue, 8 Aug 2023 12:56:15 +0200 Subject: [PATCH] fix(action): unexpected behaviour when watching non-existing Actions When passed an action with an unknown ID, the `ActionClient.WatchProgress()` and `ActionClient.WatchOverallProgress()` methods experience unexpected behaviour: - `WatchProgress()` would panic with a nil pointer dereference when accessing `a.Status` - `WatchOverallProgress()` would go into an infinite loop, never finishing or closing the channels. --- hcloud/action.go | 11 ++++ hcloud/action_test.go | 132 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/hcloud/action.go b/hcloud/action.go index e4cf712b..7e80dee5 100644 --- a/hcloud/action.go +++ b/hcloud/action.go @@ -216,6 +216,13 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti errCh <- err return } + if len(as) == 0 { + // No actions returned for the provided IDs, they do not exist in the API. + // We need to catch and fail early for this, otherwise the loop will continue + // indefinitely. + errCh <- fmt.Errorf("failed to wait for actions: remaining actions (%v) are not returned from API", opts.ID) + return + } for _, a := range as { switch a.Status { @@ -282,6 +289,10 @@ func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-cha errCh <- err return } + if a == nil { + errCh <- fmt.Errorf("failed to wait for action %d: action not returned from API", action.ID) + return + } switch a.Status { case ActionStatusRunning: diff --git a/hcloud/action_test.go b/hcloud/action_test.go index f19f26a0..b7c42e9b 100644 --- a/hcloud/action_test.go +++ b/hcloud/action_test.go @@ -6,6 +6,7 @@ import ( "errors" "net/http" "reflect" + "strings" "testing" "time" @@ -286,6 +287,85 @@ func TestActionClientWatchOverallProgress(t *testing.T) { } } +func TestActionClientWatchOverallProgressInvalidID(t *testing.T) { + env := newTestEnv() + defer env.Teardown() + + callCount := 0 + + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { + callCount++ + var actions []schema.Action + + switch callCount { + case 1: + default: + t.Errorf("unexpected number of calls to the test server: %v", callCount) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(struct { + Actions []schema.Action `json:"actions"` + Meta schema.Meta `json:"meta"` + }{ + Actions: actions, + Meta: schema.Meta{ + Pagination: &schema.MetaPagination{ + Page: 1, + LastPage: 1, + PerPage: len(actions), + TotalEntries: len(actions), + }, + }, + }) + }) + + actions := []*Action{ + { + ID: 1, + Status: ActionStatusRunning, + }, + } + + ctx := context.Background() + progressCh, errCh := env.Client.Action.WatchOverallProgress(ctx, actions) + progressUpdates := []int{} + errs := []error{} + + moreProgress, moreErrors := true, true + + for moreProgress || moreErrors { + var progress int + var err error + + select { + case progress, moreProgress = <-progressCh: + if moreProgress { + progressUpdates = append(progressUpdates, progress) + } + case err, moreErrors = <-errCh: + if moreErrors { + errs = append(errs, err) + } + } + } + + if len(errs) != 1 { + t.Fatalf("expected to receive one error: %v", errs) + } + + err := errs[0] + + if !strings.HasPrefix(err.Error(), "failed to wait for actions") { + t.Fatalf("expected failed to wait for actions error, but got: %#v", err) + } + + expectedProgressUpdates := []int{} + if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { + t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates) + } +} + func TestActionClientWatchProgress(t *testing.T) { env := newTestEnv() defer env.Teardown() @@ -385,3 +465,55 @@ func TestActionClientWatchProgressError(t *testing.T) { t.Fatal("expected an error") } } + +func TestActionClientWatchProgressInvalidID(t *testing.T) { + env := newTestEnv() + defer env.Teardown() + + callCount := 0 + + env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + switch callCount { + case 1: + _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ + Error: schema.Error{ + Code: string(ErrorCodeNotFound), + Message: "action with ID '1' not found", + Details: nil, + }, + }) + default: + t.Errorf("unexpected number of calls to the test server: %v", callCount) + } + }) + action := &Action{ + ID: 1, + } + + ctx := context.Background() + progressCh, errCh := env.Client.Action.WatchProgress(ctx, action) + var ( + progressUpdates []int + err error + ) + +loop: + for { + select { + case progress := <-progressCh: + progressUpdates = append(progressUpdates, progress) + case err = <-errCh: + break loop + } + } + + if !strings.HasPrefix(err.Error(), "failed to wait for action") { + t.Fatalf("expected failed to wait for action error, but got: %#v", err) + } + if len(progressUpdates) != 0 { + t.Fatalf("unexpected progress updates: %v", progressUpdates) + } +}