From 704af1ec61ede25d5f042a2b9db6b7011417ddf0 Mon Sep 17 00:00:00 2001 From: jo Date: Tue, 30 Apr 2024 11:59:40 +0200 Subject: [PATCH] feat: deprecate WatchOverallProgress and WatchProgress function --- hcloud/action_watch.go | 120 ++++++++++++------------------------ hcloud/action_watch_test.go | 83 +++++++++++-------------- 2 files changed, 75 insertions(+), 128 deletions(-) diff --git a/hcloud/action_watch.go b/hcloud/action_watch.go index 5cb0773d5..db3464f11 100644 --- a/hcloud/action_watch.go +++ b/hcloud/action_watch.go @@ -3,7 +3,6 @@ package hcloud import ( "context" "fmt" - "time" ) // WatchOverallProgress watches several actions' progress until they complete @@ -24,6 +23,8 @@ import ( // // WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait // until sending the next request. +// +// Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead. func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Action) (<-chan int, <-chan error) { errCh := make(chan error, len(actions)) progressCh := make(chan int) @@ -32,66 +33,37 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti defer close(errCh) defer close(progressCh) - completedIDs := make([]int64, 0, len(actions)) - watchIDs := make(map[int64]struct{}, len(actions)) - for _, action := range actions { - watchIDs[action.ID] = struct{}{} - } - - retries := 0 - previousProgress := 0 - - for { - select { - case <-ctx.Done(): - errCh <- ctx.Err() - return - case <-time.After(c.action.client.pollBackoffFunc(retries)): - retries++ - } - - opts := ActionListOpts{} - for watchID := range watchIDs { - opts.ID = append(opts.ID, watchID) + previousGlobalProgress := 0 + progressByAction := make(map[int64]int, len(actions)) + err := c.WaitForFunc(ctx, func(update *Action) error { + switch update.Status { + case ActionStatusRunning: + progressByAction[update.ID] = update.Progress + case ActionStatusSuccess: + progressByAction[update.ID] = 100 + case ActionStatusError: + progressByAction[update.ID] = 100 + errCh <- fmt.Errorf("action %d failed: %w", update.ID, update.Error()) } - as, err := c.AllWithOpts(ctx, opts) - if err != nil { - errCh <- err - return - } - if len(as) == 0 { - // No actions returned for the provided IDs, they do not exist in the API. - // We need to catch and fail early for this, otherwise the loop will continue - // indefinitely. - errCh <- fmt.Errorf("failed to wait for actions: remaining actions (%v) are not returned from API", opts.ID) - return + // Compute global progress + progressSum := 0 + for _, value := range progressByAction { + progressSum += value } + globalProgress := progressSum / len(actions) - progress := 0 - for _, a := range as { - switch a.Status { - case ActionStatusRunning: - progress += a.Progress - case ActionStatusSuccess: - delete(watchIDs, a.ID) - completedIDs = append(completedIDs, a.ID) - case ActionStatusError: - delete(watchIDs, a.ID) - completedIDs = append(completedIDs, a.ID) - errCh <- fmt.Errorf("action %d failed: %w", a.ID, a.Error()) - } + // Only send progress when it changed + if globalProgress != 0 && globalProgress != previousGlobalProgress { + sendProgress(progressCh, globalProgress) + previousGlobalProgress = globalProgress } - progress += len(completedIDs) * 100 - if progress != 0 && progress != previousProgress { - sendProgress(progressCh, progress/len(actions)) - previousProgress = progress - } + return nil + }, actions...) - if len(watchIDs) == 0 { - return - } + if err != nil { + errCh <- err } }() @@ -116,6 +88,8 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti // // WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until // sending the next request. +// +// Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead. func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error) { errCh := make(chan error, 1) progressCh := make(chan int) @@ -124,38 +98,22 @@ func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-cha defer close(errCh) defer close(progressCh) - retries := 0 - - for { - select { - case <-ctx.Done(): - errCh <- ctx.Err() - return - case <-time.After(c.action.client.pollBackoffFunc(retries)): - retries++ - } - - a, _, err := c.GetByID(ctx, action.ID) - if err != nil { - errCh <- err - return - } - if a == nil { - errCh <- fmt.Errorf("failed to wait for action %d: action not returned from API", action.ID) - return - } - - switch a.Status { + err := c.WaitForFunc(ctx, func(update *Action) error { + switch update.Status { case ActionStatusRunning: - sendProgress(progressCh, a.Progress) + sendProgress(progressCh, update.Progress) case ActionStatusSuccess: sendProgress(progressCh, 100) - errCh <- nil - return case ActionStatusError: - errCh <- a.Error() - return + // Do not wrap the action error + return update.Error() } + + return nil + }, action) + + if err != nil { + errCh <- err } }() diff --git a/hcloud/action_watch_test.go b/hcloud/action_watch_test.go index 8bed9da73..96eaece5c 100644 --- a/hcloud/action_watch_test.go +++ b/hcloud/action_watch_test.go @@ -6,9 +6,10 @@ import ( "errors" "net/http" "reflect" - "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/hetznercloud/hcloud-go/v2/hcloud/schema" ) @@ -127,7 +128,7 @@ func TestActionClientWatchOverallProgress(t *testing.T) { t.Fatalf("expected hcloud.Error, but got: %#v", err) } - expectedProgressUpdates := []int{50, 100} + expectedProgressUpdates := []int{25, 62, 100} if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates) } @@ -202,9 +203,7 @@ func TestActionClientWatchOverallProgressInvalidID(t *testing.T) { err := errs[0] - if !strings.HasPrefix(err.Error(), "failed to wait for actions") { - t.Fatalf("expected failed to wait for actions error, but got: %#v", err) - } + assert.Equal(t, "actions not found: [1]", err.Error()) expectedProgressUpdates := []int{} if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { @@ -218,39 +217,36 @@ func TestActionClientWatchProgress(t *testing.T) { callCount := 0 - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { callCount++ w.Header().Set("Content-Type", "application/json") switch callCount { case 1: - _ = json.NewEncoder(w).Encode(schema.ActionGetResponse{ - Action: schema.Action{ - ID: 1, - Status: "running", - Progress: 50, - }, - }) + _, _ = w.Write([]byte(`{ + "actions": [ + { "id": 1, "status": "running", "progress": 50 } + ], + "meta": { "pagination": { "page": 1 }} + }`)) case 2: w.WriteHeader(http.StatusConflict) - _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ - Error: schema.Error{ - Code: string(ErrorCodeConflict), - Message: "conflict", - }, - }) + _, _ = w.Write([]byte(`{ + "error": { + "code": "conflict", + "message": "conflict" + } + }`)) return case 3: - _ = json.NewEncoder(w).Encode(schema.ActionGetResponse{ - Action: schema.Action{ - ID: 1, - Status: "error", - Progress: 100, - Error: &schema.ActionError{ - Code: "action_failed", - Message: "action failed", - }, - }, - }) + _, _ = w.Write([]byte(`{ + "actions": [ + { "id": 1, "status": "error", "progress": 100, "error": { + "code": "action_failed", + "message": "action failed" + } } + ], + "meta": { "pagination": { "page": 1 }} + }`)) default: t.Errorf("unexpected number of calls to the test server: %v", callCount) } @@ -293,7 +289,7 @@ func TestActionClientWatchProgressError(t *testing.T) { env := newTestEnv() defer env.Teardown() - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnprocessableEntity) _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ @@ -304,7 +300,7 @@ func TestActionClientWatchProgressError(t *testing.T) { }) }) - action := &Action{ID: 1} + action := &Action{ID: 1, Status: ActionStatusRunning} ctx := context.Background() _, errCh := env.Client.Action.WatchProgress(ctx, action) if err := <-errCh; err == nil { @@ -318,26 +314,20 @@ func TestActionClientWatchProgressInvalidID(t *testing.T) { callCount := 0 - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { callCount++ w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotFound) switch callCount { case 1: - _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ - Error: schema.Error{ - Code: string(ErrorCodeNotFound), - Message: "action with ID '1' not found", - Details: nil, - }, - }) + _, _ = w.Write([]byte(`{ + "actions": [], + "meta": { "pagination": { "page": 1 }} + }`)) default: t.Errorf("unexpected number of calls to the test server: %v", callCount) } }) - action := &Action{ - ID: 1, - } + action := &Action{ID: 1, Status: ActionStatusRunning} ctx := context.Background() progressCh, errCh := env.Client.Action.WatchProgress(ctx, action) @@ -356,9 +346,8 @@ loop: } } - if !strings.HasPrefix(err.Error(), "failed to wait for action") { - t.Fatalf("expected failed to wait for action error, but got: %#v", err) - } + assert.Equal(t, "actions not found: [1]", err.Error()) + if len(progressUpdates) != 0 { t.Fatalf("unexpected progress updates: %v", progressUpdates) }