diff --git a/hcloud/action_waiter.go b/hcloud/action_waiter.go new file mode 100644 index 00000000..ee024c4a --- /dev/null +++ b/hcloud/action_waiter.go @@ -0,0 +1,116 @@ +package hcloud + +import ( + "context" + "fmt" + "maps" + "slices" + "time" +) + +type ActionWaiter interface { + WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error + WaitFor(ctx context.Context, actions ...*Action) error +} + +var _ ActionWaiter = (*ActionClient)(nil) + +// WaitForFunc waits until all actions are completed by polling the API at the interval +// defined by [WithPollBackoffFunc]. An action is considered as complete when its status is +// either [ActionStatusSuccess] or [ActionStatusError]. +// +// The handleUpdate callback is called every time an action is updated. +func (c *ActionClient) WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error { + running := make(map[int64]struct{}, len(actions)) + for _, action := range actions { + if action.Status == ActionStatusRunning { + running[action.ID] = struct{}{} + } else if handleUpdate != nil { + // We filter out already completed actions from the API polling loop; while + // this isn't a real update, the caller should be notified about the new + // state. + if err := handleUpdate(action); err != nil { + return err + } + } + } + + retries := 0 + for { + if len(running) == 0 { + break + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(c.action.client.pollBackoffFunc(retries)): + retries++ + } + + opts := ActionListOpts{ + Sort: []string{"status", "id"}, + ID: make([]int64, 0, len(running)), + } + for actionID := range running { + opts.ID = append(opts.ID, actionID) + } + slices.Sort(opts.ID) + + updates, err := c.AllWithOpts(ctx, opts) + if err != nil { + return err + } + + if len(updates) != len(running) { + // Some actions may not exist in the API, also fail early to prevent an + // infinite loop when updates == 0. + + notFound := maps.Clone(running) + for _, update := range updates { + delete(notFound, update.ID) + } + notFoundIDs := make([]int64, 0, len(notFound)) + for unknownID := range notFound { + notFoundIDs = append(notFoundIDs, unknownID) + } + + return fmt.Errorf("actions not found: %v", notFoundIDs) + } + + for _, update := range updates { + if update.Status != ActionStatusRunning { + delete(running, update.ID) + } + + if handleUpdate != nil { + if err := handleUpdate(update); err != nil { + return err + } + } + } + } + + return nil +} + +// WaitFor waits until all actions succeed by polling the API at the interval defined by +// [WithPollBackoffFunc]. An action is considered as succeeded when its status is either +// [ActionStatusSuccess]. +// +// If a single action fails, the function will stop waiting and the error set in the +// action will be returned as an [ActionError]. +// +// For more flexibility, see the [WaitForFunc] function. +func (c *ActionClient) WaitFor(ctx context.Context, actions ...*Action) error { + return c.WaitForFunc( + ctx, + func(update *Action) error { + if update.Status == ActionStatusError { + return update.Error() + } + return nil + }, + actions..., + ) +} diff --git a/hcloud/action_waiter_test.go b/hcloud/action_waiter_test.go new file mode 100644 index 00000000..97183e31 --- /dev/null +++ b/hcloud/action_waiter_test.go @@ -0,0 +1,169 @@ +package hcloud + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWaitFor(t *testing.T) { + RunMockedTestCases(t, + []MockedTestCase{ + { + Name: "succeed", + WantRequests: []MockedRequest{ + {"GET", "/actions?id=1509772237&page=1&sort=status&sort=id", nil, 200, + `{ + "actions": [ + { "id": 1509772237, "status": "running", "progress": 0 } + ], + "meta": { "pagination": { "page": 1 }} + }`}, + {"GET", "/actions?id=1509772237&page=1&sort=status&sort=id", nil, 200, + `{ + "actions": [ + { "id": 1509772237, "status": "success", "progress": 100 } + ], + "meta": { "pagination": { "page": 1 }} + }`}, + }, + Run: func(env testEnv) { + actions := []*Action{{ID: 1509772237, Status: ActionStatusRunning}} + + err := env.Client.Action.WaitFor(context.Background(), actions...) + assert.NoError(t, err) + }, + }, + { + Name: "succeed with already succeeded action", + Run: func(env testEnv) { + actions := []*Action{{ID: 1509772237, Status: ActionStatusSuccess}} + + err := env.Client.Action.WaitFor(context.Background(), actions...) + assert.NoError(t, err) + }, + }, + { + Name: "fail with unknown action", + WantRequests: []MockedRequest{ + {"GET", "/actions?id=1509772237&page=1&sort=status&sort=id", nil, 200, + `{ + "actions": [], + "meta": { "pagination": { "page": 1 }} + }`}, + }, + Run: func(env testEnv) { + actions := []*Action{{ID: 1509772237, Status: ActionStatusRunning}} + + err := env.Client.Action.WaitFor(context.Background(), actions...) + assert.Error(t, err) + assert.Equal(t, "actions not found: [1509772237]", err.Error()) + }, + }, + { + Name: "fail with canceled context", + Run: func(env testEnv) { + actions := []*Action{{ID: 1509772237, Status: ActionStatusRunning}} + + ctx, cancelFunc := context.WithCancel(context.Background()) + cancelFunc() + err := env.Client.Action.WaitFor(ctx, actions...) + assert.Error(t, err) + }, + }, + { + Name: "fail with api error", + WantRequests: []MockedRequest{ + {"GET", "/actions?id=1509772237&page=1&sort=status&sort=id", nil, 503, ""}, + }, + Run: func(env testEnv) { + actions := []*Action{{ID: 1509772237, Status: ActionStatusRunning}} + + err := env.Client.Action.WaitFor(context.Background(), actions...) + assert.Error(t, err) + assert.Equal(t, "hcloud: server responded with status code 503", err.Error()) + }, + }, + }, + ) +} + +func TestWaitForFunc(t *testing.T) { + RunMockedTestCases(t, + []MockedTestCase{ + { + Name: "succeed", + WantRequests: []MockedRequest{ + {"GET", "/actions?id=1509772237&id=1509772238&page=1&sort=status&sort=id", nil, 200, + `{ + "actions": [ + { "id": 1509772237, "status": "running", "progress": 40 }, + { "id": 1509772238, "status": "running", "progress": 0 } + ], + "meta": { "pagination": { "page": 1 }} + }`}, + {"GET", "/actions?id=1509772237&id=1509772238&page=1&sort=status&sort=id", nil, 200, + `{ + "actions": [ + { "id": 1509772237, "status": "running", "progress": 60 }, + { "id": 1509772238, "status": "running", "progress": 50 } + ], + "meta": { "pagination": { "page": 1 }} + }`}, + {"GET", "/actions?id=1509772237&id=1509772238&page=1&sort=status&sort=id", nil, 200, + `{ + "actions": [ + { "id": 1509772237, "status": "success", "progress": 100 }, + { "id": 1509772238, "status": "running", "progress": 75 } + ], + "meta": { "pagination": { "page": 1 }} + }`}, + {"GET", "/actions?id=1509772238&page=1&sort=status&sort=id", nil, 200, + `{ + "actions": [ + { "id": 1509772238, "status": "error", "progress": 75, + "error": { + "code": "action_failed", + "message": "Something went wrong with the action" + } + } + ], + "meta": { "pagination": { "page": 1 }} + }`}, + }, + Run: func(env testEnv) { + actions := []*Action{ + {ID: 1509772236, Status: ActionStatusSuccess}, + {ID: 1509772237, Status: ActionStatusRunning}, + {ID: 1509772238, Status: ActionStatusRunning}, + } + progress := make([]int, 0) + + progressByAction := make(map[int64]int, len(actions)) + err := env.Client.Action.WaitForFunc(context.Background(), func(update *Action) error { + switch update.Status { + case ActionStatusRunning: + progressByAction[update.ID] = update.Progress + case ActionStatusSuccess: + progressByAction[update.ID] = 100 + case ActionStatusError: + progressByAction[update.ID] = 100 + } + + sum := 0 + for _, value := range progressByAction { + sum += value + } + progress = append(progress, sum/len(actions)) + + return nil + }, actions...) + + assert.Nil(t, err) + assert.Equal(t, []int{33, 46, 46, 53, 70, 83, 91, 100}, progress) + }, + }, + }, + ) +} diff --git a/hcloud/action_watch.go b/hcloud/action_watch.go index 5cb0773d..db3464f1 100644 --- a/hcloud/action_watch.go +++ b/hcloud/action_watch.go @@ -3,7 +3,6 @@ package hcloud import ( "context" "fmt" - "time" ) // WatchOverallProgress watches several actions' progress until they complete @@ -24,6 +23,8 @@ import ( // // WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait // until sending the next request. +// +// Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead. func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Action) (<-chan int, <-chan error) { errCh := make(chan error, len(actions)) progressCh := make(chan int) @@ -32,66 +33,37 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti defer close(errCh) defer close(progressCh) - completedIDs := make([]int64, 0, len(actions)) - watchIDs := make(map[int64]struct{}, len(actions)) - for _, action := range actions { - watchIDs[action.ID] = struct{}{} - } - - retries := 0 - previousProgress := 0 - - for { - select { - case <-ctx.Done(): - errCh <- ctx.Err() - return - case <-time.After(c.action.client.pollBackoffFunc(retries)): - retries++ - } - - opts := ActionListOpts{} - for watchID := range watchIDs { - opts.ID = append(opts.ID, watchID) + previousGlobalProgress := 0 + progressByAction := make(map[int64]int, len(actions)) + err := c.WaitForFunc(ctx, func(update *Action) error { + switch update.Status { + case ActionStatusRunning: + progressByAction[update.ID] = update.Progress + case ActionStatusSuccess: + progressByAction[update.ID] = 100 + case ActionStatusError: + progressByAction[update.ID] = 100 + errCh <- fmt.Errorf("action %d failed: %w", update.ID, update.Error()) } - as, err := c.AllWithOpts(ctx, opts) - if err != nil { - errCh <- err - return - } - if len(as) == 0 { - // No actions returned for the provided IDs, they do not exist in the API. - // We need to catch and fail early for this, otherwise the loop will continue - // indefinitely. - errCh <- fmt.Errorf("failed to wait for actions: remaining actions (%v) are not returned from API", opts.ID) - return + // Compute global progress + progressSum := 0 + for _, value := range progressByAction { + progressSum += value } + globalProgress := progressSum / len(actions) - progress := 0 - for _, a := range as { - switch a.Status { - case ActionStatusRunning: - progress += a.Progress - case ActionStatusSuccess: - delete(watchIDs, a.ID) - completedIDs = append(completedIDs, a.ID) - case ActionStatusError: - delete(watchIDs, a.ID) - completedIDs = append(completedIDs, a.ID) - errCh <- fmt.Errorf("action %d failed: %w", a.ID, a.Error()) - } + // Only send progress when it changed + if globalProgress != 0 && globalProgress != previousGlobalProgress { + sendProgress(progressCh, globalProgress) + previousGlobalProgress = globalProgress } - progress += len(completedIDs) * 100 - if progress != 0 && progress != previousProgress { - sendProgress(progressCh, progress/len(actions)) - previousProgress = progress - } + return nil + }, actions...) - if len(watchIDs) == 0 { - return - } + if err != nil { + errCh <- err } }() @@ -116,6 +88,8 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti // // WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until // sending the next request. +// +// Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead. func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error) { errCh := make(chan error, 1) progressCh := make(chan int) @@ -124,38 +98,22 @@ func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-cha defer close(errCh) defer close(progressCh) - retries := 0 - - for { - select { - case <-ctx.Done(): - errCh <- ctx.Err() - return - case <-time.After(c.action.client.pollBackoffFunc(retries)): - retries++ - } - - a, _, err := c.GetByID(ctx, action.ID) - if err != nil { - errCh <- err - return - } - if a == nil { - errCh <- fmt.Errorf("failed to wait for action %d: action not returned from API", action.ID) - return - } - - switch a.Status { + err := c.WaitForFunc(ctx, func(update *Action) error { + switch update.Status { case ActionStatusRunning: - sendProgress(progressCh, a.Progress) + sendProgress(progressCh, update.Progress) case ActionStatusSuccess: sendProgress(progressCh, 100) - errCh <- nil - return case ActionStatusError: - errCh <- a.Error() - return + // Do not wrap the action error + return update.Error() } + + return nil + }, action) + + if err != nil { + errCh <- err } }() diff --git a/hcloud/action_watch_test.go b/hcloud/action_watch_test.go index 8bed9da7..96eaece5 100644 --- a/hcloud/action_watch_test.go +++ b/hcloud/action_watch_test.go @@ -6,9 +6,10 @@ import ( "errors" "net/http" "reflect" - "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/hetznercloud/hcloud-go/v2/hcloud/schema" ) @@ -127,7 +128,7 @@ func TestActionClientWatchOverallProgress(t *testing.T) { t.Fatalf("expected hcloud.Error, but got: %#v", err) } - expectedProgressUpdates := []int{50, 100} + expectedProgressUpdates := []int{25, 62, 100} if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates) } @@ -202,9 +203,7 @@ func TestActionClientWatchOverallProgressInvalidID(t *testing.T) { err := errs[0] - if !strings.HasPrefix(err.Error(), "failed to wait for actions") { - t.Fatalf("expected failed to wait for actions error, but got: %#v", err) - } + assert.Equal(t, "actions not found: [1]", err.Error()) expectedProgressUpdates := []int{} if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { @@ -218,39 +217,36 @@ func TestActionClientWatchProgress(t *testing.T) { callCount := 0 - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { callCount++ w.Header().Set("Content-Type", "application/json") switch callCount { case 1: - _ = json.NewEncoder(w).Encode(schema.ActionGetResponse{ - Action: schema.Action{ - ID: 1, - Status: "running", - Progress: 50, - }, - }) + _, _ = w.Write([]byte(`{ + "actions": [ + { "id": 1, "status": "running", "progress": 50 } + ], + "meta": { "pagination": { "page": 1 }} + }`)) case 2: w.WriteHeader(http.StatusConflict) - _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ - Error: schema.Error{ - Code: string(ErrorCodeConflict), - Message: "conflict", - }, - }) + _, _ = w.Write([]byte(`{ + "error": { + "code": "conflict", + "message": "conflict" + } + }`)) return case 3: - _ = json.NewEncoder(w).Encode(schema.ActionGetResponse{ - Action: schema.Action{ - ID: 1, - Status: "error", - Progress: 100, - Error: &schema.ActionError{ - Code: "action_failed", - Message: "action failed", - }, - }, - }) + _, _ = w.Write([]byte(`{ + "actions": [ + { "id": 1, "status": "error", "progress": 100, "error": { + "code": "action_failed", + "message": "action failed" + } } + ], + "meta": { "pagination": { "page": 1 }} + }`)) default: t.Errorf("unexpected number of calls to the test server: %v", callCount) } @@ -293,7 +289,7 @@ func TestActionClientWatchProgressError(t *testing.T) { env := newTestEnv() defer env.Teardown() - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnprocessableEntity) _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ @@ -304,7 +300,7 @@ func TestActionClientWatchProgressError(t *testing.T) { }) }) - action := &Action{ID: 1} + action := &Action{ID: 1, Status: ActionStatusRunning} ctx := context.Background() _, errCh := env.Client.Action.WatchProgress(ctx, action) if err := <-errCh; err == nil { @@ -318,26 +314,20 @@ func TestActionClientWatchProgressInvalidID(t *testing.T) { callCount := 0 - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { callCount++ w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotFound) switch callCount { case 1: - _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ - Error: schema.Error{ - Code: string(ErrorCodeNotFound), - Message: "action with ID '1' not found", - Details: nil, - }, - }) + _, _ = w.Write([]byte(`{ + "actions": [], + "meta": { "pagination": { "page": 1 }} + }`)) default: t.Errorf("unexpected number of calls to the test server: %v", callCount) } }) - action := &Action{ - ID: 1, - } + action := &Action{ID: 1, Status: ActionStatusRunning} ctx := context.Background() progressCh, errCh := env.Client.Action.WatchProgress(ctx, action) @@ -356,9 +346,8 @@ loop: } } - if !strings.HasPrefix(err.Error(), "failed to wait for action") { - t.Fatalf("expected failed to wait for action error, but got: %#v", err) - } + assert.Equal(t, "actions not found: [1]", err.Error()) + if len(progressUpdates) != 0 { t.Fatalf("unexpected progress updates: %v", progressUpdates) } diff --git a/hcloud/client_test.go b/hcloud/client_test.go index 49d20232..92f478de 100644 --- a/hcloud/client_test.go +++ b/hcloud/client_test.go @@ -31,6 +31,10 @@ func (env *testEnv) Teardown() { func newTestEnv() testEnv { mux := http.NewServeMux() server := httptest.NewServer(mux) + return newTestEnvWithServer(server, mux) +} + +func newTestEnvWithServer(server *httptest.Server, mux *http.ServeMux) testEnv { client := NewClient( WithEndpoint(server.URL), WithToken("token"), diff --git a/hcloud/interface_gen.go b/hcloud/interface_gen.go index 7074ca9a..2ae8cecb 100644 --- a/hcloud/interface_gen.go +++ b/hcloud/interface_gen.go @@ -1,6 +1,6 @@ package hcloud -//go:generate go run github.com/vburenin/ifacemaker -f action.go -f action_watch.go -s ActionClient -i IActionClient -p hcloud -o zz_action_client_iface.go +//go:generate go run github.com/vburenin/ifacemaker -f action.go -f action_watch.go -f action_waiter.go -s ActionClient -i IActionClient -p hcloud -o zz_action_client_iface.go //go:generate go run github.com/vburenin/ifacemaker -f action.go -s ResourceActionClient -i IResourceActionClient -p hcloud -o zz_resource_action_client_iface.go //go:generate go run github.com/vburenin/ifacemaker -f datacenter.go -s DatacenterClient -i IDatacenterClient -p hcloud -o zz_datacenter_client_iface.go //go:generate go run github.com/vburenin/ifacemaker -f floating_ip.go -s FloatingIPClient -i IFloatingIPClient -p hcloud -o zz_floating_ip_client_iface.go diff --git a/hcloud/mocked_test.go b/hcloud/mocked_test.go new file mode 100644 index 00000000..0799a1f2 --- /dev/null +++ b/hcloud/mocked_test.go @@ -0,0 +1,75 @@ +package hcloud + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +type MockedTestCase struct { + Name string + WantRequests []MockedRequest + Run func(env testEnv) +} + +type MockedRequest struct { + Method string + Path string + WantRequestBodyFunc func(t *testing.T, r *http.Request, body []byte) + + Status int + Body string +} + +func RunMockedTestCases(t *testing.T, testCases []MockedTestCase) { + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + env := newTestEnvWithServer(httptest.NewServer(MockedRequestHandler(t, testCase.WantRequests)), nil) + defer env.Teardown() + + testCase.Run(env) + }) + } +} + +func MockedRequestHandler(t *testing.T, requests []MockedRequest) http.HandlerFunc { + index := 0 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if testing.Verbose() { + t.Logf("request %d: %s %s\n", index, r.Method, r.URL.Path) + } + + if index >= len(requests) { + t.Fatalf("received unknown request %d", index) + } + + response := requests[index] + assert.Equal(t, response.Method, r.Method) + assert.Equal(t, response.Path, r.RequestURI) + + if response.WantRequestBodyFunc != nil { + buffer, err := io.ReadAll(r.Body) + defer func() { + if err := r.Body.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatal(err) + } + response.WantRequestBodyFunc(t, r, buffer) + } + + w.WriteHeader(response.Status) + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(response.Body)) + if err != nil { + t.Fatal(err) + } + + index++ + }) +} diff --git a/hcloud/zz_action_client_iface.go b/hcloud/zz_action_client_iface.go index 1917574a..c9b584cb 100644 --- a/hcloud/zz_action_client_iface.go +++ b/hcloud/zz_action_client_iface.go @@ -37,6 +37,8 @@ type IActionClient interface { // // WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait // until sending the next request. + // + // Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead. WatchOverallProgress(ctx context.Context, actions []*Action) (<-chan int, <-chan error) // WatchProgress watches one action's progress until it completes with success // or error. This watching happens in a goroutine and updates are provided @@ -56,5 +58,22 @@ type IActionClient interface { // // WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until // sending the next request. + // + // Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead. WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error) + // WaitForFunc waits until all actions are completed by polling the API at the interval + // defined by [WithPollBackoffFunc]. An action is considered as complete when its status is + // either [ActionStatusSuccess] or [ActionStatusError]. + // + // The handleUpdate callback is called every time an action is updated. + WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error + // WaitFor waits until all actions succeed by polling the API at the interval defined by + // [WithPollBackoffFunc]. An action is considered as succeeded when its status is either + // [ActionStatusSuccess]. + // + // If a single action fails, the function will stop waiting and the error set in the + // action will be returned as an [ActionError]. + // + // For more flexibility, see the [WaitForFunc] function. + WaitFor(ctx context.Context, actions ...*Action) error }