From f54f4fcad032b3e1c394d073d91baa4854ba0cb5 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Mon, 7 Jun 2021 15:20:13 -0700 Subject: [PATCH 01/21] LRO poller rewrite Simplified implementation of LRO pollers for ARM. Public surface area has been slightly changed, making it identical to the data-plane implementation. The different polling mechanisms have been split into internal packages, with an exported LROPoller that implements the overall polling algorithm. --- sdk/armcore/go.mod | 4 +- sdk/armcore/go.sum | 9 +- sdk/armcore/internal/pollers/async/async.go | 109 ++ .../internal/pollers/async/async_test.go | 161 +++ sdk/armcore/internal/pollers/body/body.go | 78 ++ .../internal/pollers/body/body_test.go | 130 +++ sdk/armcore/internal/pollers/loc/loc.go | 67 ++ sdk/armcore/internal/pollers/loc/loc_test.go | 86 ++ sdk/armcore/internal/pollers/pollers.go | 120 ++ sdk/armcore/internal/pollers/pollers_test.go | 111 ++ sdk/armcore/poller.go | 1030 +++-------------- sdk/armcore/poller_test.go | 264 +++-- sdk/armcore/version.go | 2 +- 13 files changed, 1210 insertions(+), 961 deletions(-) create mode 100644 sdk/armcore/internal/pollers/async/async.go create mode 100644 sdk/armcore/internal/pollers/async/async_test.go create mode 100644 sdk/armcore/internal/pollers/body/body.go create mode 100644 sdk/armcore/internal/pollers/body/body_test.go create mode 100644 sdk/armcore/internal/pollers/loc/loc.go create mode 100644 sdk/armcore/internal/pollers/loc/loc_test.go create mode 100644 sdk/armcore/internal/pollers/pollers.go create mode 100644 sdk/armcore/internal/pollers/pollers_test.go diff --git a/sdk/armcore/go.mod b/sdk/armcore/go.mod index 9a8cf2842281..91c7df6b1aeb 100644 --- a/sdk/armcore/go.mod +++ b/sdk/armcore/go.mod @@ -3,6 +3,6 @@ module github.com/Azure/azure-sdk-for-go/sdk/armcore go 1.14 require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0 - github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.2 + github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1 ) diff --git a/sdk/armcore/go.sum b/sdk/armcore/go.sum index 91b3874be2ef..50a97bab95d6 100644 --- a/sdk/armcore/go.sum +++ b/sdk/armcore/go.sum @@ -1,7 +1,7 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0 h1:4HBTI/9UDZN7tsXyB5TYP3xCv5xVHIUTbvHHH2HFxQY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0/go.mod h1:pElNP+u99BvCZD+0jOlhI9OC/NB2IDTOTGZOZH0Qhq8= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0 h1:HG1ggl8L3ZkV/Ydanf7lKr5kkhhPGCpWdnr1J6v7cO4= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.2 h1:UC4vfOhW2l0f2QOCQpOxJS4/K6oKFy2tQZE+uWU1MEo= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.2/go.mod h1:MVdrcUC4Hup35qHym3VdzoW+NBgBxrta9Vei97jRtM8= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1 h1:vx8McI56N5oLSQu8xa+xdiE0fjQq8W8Zt49vHP8Rygw= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -10,7 +10,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go new file mode 100644 index 000000000000..1d9aaacdb111 --- /dev/null +++ b/sdk/armcore/internal/pollers/async/async.go @@ -0,0 +1,109 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "errors" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + finalStateAsync = "azure-async-operation" + finalStateLoc = "location" + finalStateOrig = "original-uri" +) + +// Applicable returns true if the LRO is using Azure-AsyncOperation. +func Applicable(resp *azcore.Response) bool { + return resp.Header.Get(pollers.HeaderAzureAsync) != "" +} + +// Poller is an LRO poller that uses the Azure-AsyncOperation pattern. +type Poller struct { + Type string `json:"type"` + AsyncURL string `json:"asyncURL"` + LocURL string `json:"locURL"` + OrigURL string `json:"origURL"` + Method string `json:"method"` + FinalState string `json:"finalState"` + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response and final-state type. +func New(resp *azcore.Response, finalState string, pollerID string) (*Poller, error) { + azcore.Log().Write(azcore.LogLongRunningOperation, "Using Azure-AsyncOperation poller.") + asyncURL := resp.Header.Get(pollers.HeaderAzureAsync) + if asyncURL == "" { + return nil, errors.New("response is missing Azure-AsyncOperation header") + } + p := &Poller{ + Type: pollers.MakeID(pollerID, "async"), + AsyncURL: asyncURL, + LocURL: resp.Header.Get(pollers.HeaderLocation), + OrigURL: resp.Request.URL.String(), + Method: resp.Request.Method, + FinalState: finalState, + } + // check for provisioning state + state, err := pollers.GetProvisioningState(resp) + if errors.Is(err, pollers.ErrNoProvisioningState) { + if resp.Request.Method == http.MethodPut { + // initial response for a PUT requires a provisioning state + return nil, err + } + // for DELETE/PATCH/POST, provisioning state is optional + state = "InProgress" + } else if err != nil { + return nil, err + } + p.CurState = state + return p, nil +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *azcore.Response) error { + state, err := pollers.GetStatus(resp) + if err != nil { + return err + } + p.CurState = state + return nil +} + +// FinalGetURL returns the URL to perform a final GET for the payload, or the empty string if not required. +func (p *Poller) FinalGetURL() string { + if p.Method == http.MethodPatch || p.Method == http.MethodPut { + // for PATCH and PUT, the final GET is on the original resource URL + return p.OrigURL + } else if p.Method == http.MethodPost { + // for POST, we need to consult the final-state-via flag + if p.FinalState == finalStateLoc && p.LocURL != "" { + return p.LocURL + } else if p.FinalState == finalStateOrig { + return p.OrigURL + } + // finalStateAsync fall through + } + return "" +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.AsyncURL +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/armcore/internal/pollers/async/async_test.go b/sdk/armcore/internal/pollers/async/async_test.go new file mode 100644 index 000000000000..30b88a4e2f5d --- /dev/null +++ b/sdk/armcore/internal/pollers/async/async_test.go @@ -0,0 +1,161 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + fakePollingURL = "https://foo.bar.baz/status" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, resp io.Reader) *azcore.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &azcore.Response{ + Response: &http.Response{ + Body: io.NopCloser(resp), + Header: http.Header{}, + Request: req, + }, + } +} + +func pollingResponse(resp io.Reader) *azcore.Response { + return &azcore.Response{ + Response: &http.Response{ + Body: io.NopCloser(resp), + Header: http.Header{}, + }, + } +} + +func TestApplicable(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + }, + } + if Applicable(&resp) { + t.Fatal("missing Azure-AsyncOperation should not be applicable") + } + resp.Response.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + if !Applicable(&resp) { + t.Fatal("having Azure-AsyncOperation should be applicable") + } +} + +func TestNew(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(strings.NewReader(`{ "status": "InProgress" }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewDeleteNoProvState(t *testing.T) { + resp := initialResponse(http.MethodDelete, http.NoBody) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewFail(t *testing.T) { + // missing provisioning state on initial response + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} + +func TestNewFinalGetLocation(t *testing.T) { + const ( + jsonBody = `{ "properties": { "provisioningState": "Started" } }` + locURL = "https://foo.bar.baz/location" + ) + resp := initialResponse(http.MethodPost, strings.NewReader(jsonBody)) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + resp.Header.Set(pollers.HeaderLocation, locURL) + poller, err := New(resp, "location", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != locURL { + t.Fatalf("unexpected final get URL %s", u) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewFinalGetOrigin(t *testing.T) { + const ( + jsonBody = `{ "properties": { "provisioningState": "Started" } }` + locURL = "https://foo.bar.baz/location" + ) + resp := initialResponse(http.MethodPost, strings.NewReader(jsonBody)) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + resp.Header.Set(pollers.HeaderLocation, locURL) + poller, err := New(resp, "original-uri", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } +} diff --git a/sdk/armcore/internal/pollers/body/body.go b/sdk/armcore/internal/pollers/body/body.go new file mode 100644 index 000000000000..331ac63d4aa1 --- /dev/null +++ b/sdk/armcore/internal/pollers/body/body.go @@ -0,0 +1,78 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "errors" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// Applicable returns true if the LRO is using no headers, just provisioning state. +// This is only applicable to PATCH and PUT methods and assumes no polling headers. +func Applicable(resp *azcore.Response) bool { + // we can't check for absense of headers due to some misbehaving services + // like redis that return a Location header but don't actually use that protocol + return resp.Request.Method == http.MethodPatch || resp.Request.Method == http.MethodPut +} + +// Poller is an LRO poller that uses the Body pattern. +type Poller struct { + Type string `json:"type"` + PollURL string `json:"pollURL"` + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *azcore.Response, pollerID string) (*Poller, error) { + azcore.Log().Write(azcore.LogLongRunningOperation, "Using Body poller.") + p := &Poller{ + Type: pollers.MakeID(pollerID, "body"), + PollURL: resp.Request.URL.String(), + } + // the initial response must contain a provisioning state + state, err := pollers.GetProvisioningState(resp) + if err != nil { + return nil, err + } + p.CurState = state + return p, nil +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.PollURL +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *azcore.Response) error { + state, err := pollers.GetProvisioningState(resp) + if errors.Is(err, pollers.ErrNoProvisioningState) { + // absense of any provisioning state is considered terminal success + state = "Succeeded" + } else if err != nil { + return err + } + p.CurState = state + return nil +} + +// FinalGetURL returns the empty string as no final GET is required for this poller type. +func (*Poller) FinalGetURL() string { + return "" +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/armcore/internal/pollers/body/body_test.go b/sdk/armcore/internal/pollers/body/body_test.go new file mode 100644 index 000000000000..d67a74f02a0f --- /dev/null +++ b/sdk/armcore/internal/pollers/body/body_test.go @@ -0,0 +1,130 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, resp io.Reader) *azcore.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &azcore.Response{ + Response: &http.Response{ + Body: io.NopCloser(resp), + Header: http.Header{}, + Request: req, + }, + } +} + +func pollingResponse(resp io.Reader) *azcore.Response { + return &azcore.Response{ + Response: &http.Response{ + Body: io.NopCloser(resp), + Header: http.Header{}, + }, + } +} + +func TestApplicable(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + Request: &http.Request{ + Method: http.MethodDelete, + }, + }, + } + if Applicable(&resp) { + t.Fatal("method DELETE should not be applicable") + } + resp.Request.Method = http.MethodPatch + if !Applicable(&resp) { + t.Fatal("method PATCH should be applicable") + } + resp.Request.Method = http.MethodPut + if !Applicable(&resp) { + t.Fatal("method PUT should be applicable") + } +} + +func TestNew(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(strings.NewReader(`{ "properties": { "provisioningState": "InProgress" } }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateNoProvState(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewFail(t *testing.T) { + // missing provisioning state on initial response + resp := initialResponse(http.MethodPut, http.NoBody) + poller, err := New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go new file mode 100644 index 000000000000..1b941dd08254 --- /dev/null +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -0,0 +1,67 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// Applicable returns true if the LRO is using Location. +func Applicable(resp *azcore.Response) bool { + return resp.StatusCode == http.StatusAccepted && resp.Header.Get(pollers.HeaderLocation) != "" +} + +// Poller is an LRO poller that uses the Location pattern. +type Poller struct { + Type string `json:"type"` + PollURL string `json:"pollURL"` + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *azcore.Response, pollerID string) (*Poller, error) { + azcore.Log().Write(azcore.LogLongRunningOperation, "Using Location poller.") + p := &Poller{ + Type: pollers.MakeID(pollerID, "loc"), + PollURL: resp.Header.Get(pollers.HeaderLocation), + CurState: "InProgress", + } + return p, nil +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.PollURL +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *azcore.Response) error { + // any 2xx code other than 202 indicates success + if resp.HasStatusCode(http.StatusOK, http.StatusCreated, http.StatusNoContent) { + p.CurState = "Succeeded" + } else if resp.StatusCode > 399 && resp.StatusCode < 500 { + p.CurState = "Failed" + } + return nil +} + +// FinalGetURL returns the empty string as no final GET is required for this poller type. +func (p *Poller) FinalGetURL() string { + return "" +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/armcore/internal/pollers/loc/loc_test.go b/sdk/armcore/internal/pollers/loc/loc_test.go new file mode 100644 index 000000000000..bc205c6bb487 --- /dev/null +++ b/sdk/armcore/internal/pollers/loc/loc_test.go @@ -0,0 +1,86 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + fakePollingURL = "https://foo.bar.baz/status" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string) *azcore.Response { + return &azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + }, + } +} + +func pollingResponse(statusCode int) *azcore.Response { + return &azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + StatusCode: statusCode, + }, + } +} + +func TestApplicable(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + }, + } + if Applicable(&resp) { + t.Fatal("missing Location should not be applicable") + } + resp.Response.Header.Set(pollers.HeaderLocation, fakePollingURL) + if !Applicable(&resp) { + t.Fatal("having Location should be applicable") + } +} + +func TestNew(t *testing.T) { + resp := initialResponse(http.MethodPut) + resp.Header.Set(pollers.HeaderLocation, fakePollingURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusNoContent)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(http.StatusConflict)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Failed" { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go new file mode 100644 index 000000000000..0ee2c0c0e966 --- /dev/null +++ b/sdk/armcore/internal/pollers/pollers.go @@ -0,0 +1,120 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + HeaderAzureAsync = "Azure-AsyncOperation" + HeaderLocation = "Location" +) + +// reads the response body into a raw JSON object. +// returns an empty object if there was no content. +func getJSON(resp *azcore.Response) (map[string]interface{}, error) { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + resp.Body.Close() + if len(body) == 0 { + return map[string]interface{}{}, nil + } + // put the body back so it's available to others + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + // unmarshall the body to get the value + var jsonBody map[string]interface{} + if err = json.Unmarshal(body, &jsonBody); err != nil { + return nil, err + } + return jsonBody, nil +} + +// provisioningState returns the provisioning state from the response or the empty string. +func provisioningState(jsonBody map[string]interface{}) string { + jsonProps, ok := jsonBody["properties"] + if !ok { + return "" + } + props, ok := jsonProps.(map[string]interface{}) + if !ok { + return "" + } + rawPs, ok := props["provisioningState"] + if !ok { + return "" + } + ps, ok := rawPs.(string) + if !ok { + return "" + } + return ps +} + +// status returns the status from the response or the empty string. +func status(jsonBody map[string]interface{}) string { + rawStatus, ok := jsonBody["status"] + if !ok { + return "" + } + status, ok := rawStatus.(string) + if !ok { + return "" + } + return status +} + +// IsTerminalState returns true if the LRO's state is terminal. +func IsTerminalState(s string) bool { + return strings.EqualFold(s, "succeeded") || strings.EqualFold(s, "failed") || strings.EqualFold(s, "canceled") +} + +// GetStatus returns the LRO's status from the response body. +// Typically used for Azure-AsyncOperation flows. +// If there is no status in the response body ErrNoStatus is returned. +func GetStatus(resp *azcore.Response) (string, error) { + jsonBody, err := getJSON(resp) + if err != nil { + return "", err + } + if s := status(jsonBody); s != "" { + return s, nil + } + return "", ErrNoStatus +} + +// GetProvisioningState returns the LRO's state from the response body. +// If there is no state in the response body ErrNoProvisioningState is returned. +func GetProvisioningState(resp *azcore.Response) (string, error) { + jsonBody, err := getJSON(resp) + if err != nil { + return "", err + } + if ps := provisioningState(jsonBody); ps != "" { + return ps, nil + } + return "", ErrNoProvisioningState +} + +// MakeID returns the unique poller identifier in the format pollerID;poller. +func MakeID(pollerID string, kind string) string { + return fmt.Sprintf("%s;%s", pollerID, kind) +} + +// ErrNoStatus is returned if the response body didn't contain a status. +var ErrNoStatus = errors.New("the response did not contain a status") + +// ErrNoProvisioningState is returned if the response body didn't contain a provisioning state. +var ErrNoProvisioningState = errors.New("the response did not contain a provisioning state") diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go new file mode 100644 index 000000000000..eb31235a84eb --- /dev/null +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -0,0 +1,111 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +func TestIsTerminalState(t *testing.T) { + if IsTerminalState("Updating") { + t.Fatal("Updating is not a terminal state") + } + if !IsTerminalState("Succeeded") { + t.Fatal("Succeeded is a terminal state") + } + if !IsTerminalState("failed") { + t.Fatal("failed is a terminal state") + } + if !IsTerminalState("canceled") { + t.Fatal("canceled is a terminal state") + } +} + +func TestGetStatusSuccess(t *testing.T) { + const jsonBody = `{ "status": "InProgress" }` + resp := azcore.Response{ + Response: &http.Response{ + Body: io.NopCloser(strings.NewReader(jsonBody)), + }, + } + status, err := GetStatus(&resp) + if err != nil { + t.Fatal(err) + } + if status != "InProgress" { + t.Fatalf("unexpected status %s", status) + } +} + +func TestGetStatusError(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Body: http.NoBody, + }, + } + status, err := GetStatus(&resp) + if !errors.Is(err, ErrNoStatus) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatalf("expected empty status, got %s", status) + } +} + +func TestGetProvisioningState(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Canceled" } }` + resp := azcore.Response{ + Response: &http.Response{ + Body: io.NopCloser(strings.NewReader(jsonBody)), + }, + } + state, err := GetProvisioningState(&resp) + if err != nil { + t.Fatal(err) + } + if state != "Canceled" { + t.Fatalf("unexpected status %s", state) + } +} + +func TestGetProvisioningStateError(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Body: http.NoBody, + }, + } + state, err := GetProvisioningState(&resp) + if !errors.Is(err, ErrNoProvisioningState) { + t.Fatalf("unexpected error %T", err) + } + if state != "" { + t.Fatalf("expected empty provisioning state, got %s", state) + } +} + +func TestMakeID(t *testing.T) { + const ( + pollerID = "pollerID" + kind = "kind" + ) + id := MakeID(pollerID, kind) + parts := strings.Split(id, ";") + if l := len(parts); l != 2 { + t.Fatalf("unexpected length %d", l) + } + if p := parts[0]; p != pollerID { + t.Fatalf("unexpected poller ID %s", p) + } + if p := parts[1]; p != kind { + t.Fatalf("unexpected poller kind %s", p) + } +} diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index bbffe09c5c69..73c75692b07c 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -6,212 +6,215 @@ package armcore import ( - "bytes" "context" "encoding/json" "errors" "fmt" - "io" "io/ioutil" "net/http" - "net/url" - "reflect" "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/loc" "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) -const ( - headerAsyncOperation = "Azure-AsyncOperation" - headerLocation = "Location" -) - -const ( - operationInProgress string = "InProgress" - operationCanceled string = "Canceled" - operationFailed string = "Failed" - operationSucceeded string = "Succeeded" -) +// ErrorUnmarshaller is the func to invoke when the endpoint returns an error response that requires unmarshalling. +type ErrorUnmarshaller func(*azcore.Response) error -var pollingCodes = [...]int{http.StatusNoContent, http.StatusAccepted, http.StatusCreated, http.StatusOK} - -// NewPoller creates a polling tracker based on the verb of the original request and returns -// the polling tracker implementation for the method verb or an error. +// NewLROPoller creates an LROPoller based on the provided initial response. +// pollerID - a unique identifier for an LRO. it's usually the client.Method string. // NOTE: this is only meant for internal use in generated code. -func NewPoller(pollerType string, finalState string, resp *azcore.Response, errorHandler methodErrorHandler) (Poller, error) { - var pt pollingTracker - switch strings.ToUpper(resp.Request.Method) { - case http.MethodDelete: - pt = &pollingTrackerDelete{pollingTrackerBase: newPollingTrackerBase(pollerType, finalState, resp, errorHandler)} - case http.MethodPatch: - pt = &pollingTrackerPatch{pollingTrackerBase: newPollingTrackerBase(pollerType, finalState, resp, errorHandler)} - case http.MethodPost: - pt = &pollingTrackerPost{pollingTrackerBase: newPollingTrackerBase(pollerType, finalState, resp, errorHandler)} - case http.MethodPut: - pt = &pollingTrackerPut{pollingTrackerBase: newPollingTrackerBase(pollerType, finalState, resp, errorHandler)} - default: - return nil, fmt.Errorf("unsupported HTTP method %s", resp.Request.Method) - } - if err := pt.initializeState(); err != nil { - return nil, err +func NewLROPoller(pollerID string, finalState string, resp *azcore.Response, pl azcore.Pipeline, eu ErrorUnmarshaller) (*LROPoller, error) { + // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). + // ideally the codegen should return an error if the initial response failed and not even create a poller. + if !lroStatusCodeValid(resp) { + return nil, errors.New("the LRO failed or was cancelled") + } + // determine the polling method + var lro lroPoller + var err error + if async.Applicable(resp) { + lro, err = async.New(resp, finalState, pollerID) + } else if loc.Applicable(resp) { + lro, err = loc.New(resp, pollerID) + } else if body.Applicable(resp) { + // must test body poller last as it's a subset of the other pollers. + // TODO: this is ambiguous for PATCH/PUT if it returns a 200 with no polling headers (sync completion) + lro, err = body.New(resp, pollerID) + } else { + lro = &nopPoller{} } - // this initializes the polling header values, we do this during creation in case the - // initial response send us invalid values; this way the API call will return a non-nil - // error (not doing this means the error shows up in Future.Done) - if err := pt.updatePollingMethod(); err != nil { + if err != nil { return nil, err } - return &poller{pt: pt}, nil + return &LROPoller{lro: lro, pl: pl, eu: eu, resp: resp}, nil } -// NewPollerFromResumeToken creates a polling tracker from a resume token string. +// NewLROPollerFromResumeToken creates an LROPoller from a resume token string. +// pollerID - a unique identifier for an LRO. it's usually the client.Method string. // NOTE: this is only meant for internal use in generated code. -func NewPollerFromResumeToken(pollerType string, token string, errorHandler methodErrorHandler) (Poller, error) { - // unmarshal into JSON object to determine the tracker type +func NewLROPollerFromResumeToken(pollerID string, token string, pl azcore.Pipeline, eu ErrorUnmarshaller) (*LROPoller, error) { + // unmarshal into JSON object to determine the poller type obj := map[string]interface{}{} err := json.Unmarshal([]byte(token), &obj) if err != nil { return nil, err } - if obj["pollerType"] != pollerType { - return nil, fmt.Errorf("Cannot resume from this poller type. Expected: %s, Received: %s", pollerType, obj["pollerType"]) - } - if obj["method"] == nil { - return nil, fmt.Errorf("Token is missing 'method' property") - } - var pt pollingTracker - switch method := strings.ToUpper(obj["method"].(string)); method { - case http.MethodDelete: - pt = &pollingTrackerDelete{pollingTrackerBase: pollingTrackerBase{errorHandler: errorHandler}} - case http.MethodPatch: - pt = &pollingTrackerPatch{pollingTrackerBase: pollingTrackerBase{errorHandler: errorHandler}} - case http.MethodPost: - pt = &pollingTrackerPost{pollingTrackerBase: pollingTrackerBase{errorHandler: errorHandler}} - case http.MethodPut: - pt = &pollingTrackerPut{pollingTrackerBase: pollingTrackerBase{errorHandler: errorHandler}} + t, ok := obj["type"] + if !ok { + return nil, errors.New("missing type field") + } + tt, ok := t.(string) + if !ok { + return nil, fmt.Errorf("invalid type format %T", t) + } + // the type is encoded as "pollerID;pollerType" + sem := strings.LastIndex(tt, ";") + if sem < 0 { + return nil, fmt.Errorf("invalid poller type %s", tt) + } + // ensure poller types match + if received := tt[:sem]; received != pollerID { + return nil, fmt.Errorf("cannot resume from this poller token. expected %s, received %s", pollerID, received) + } + // now rehydrate the poller based on the encoded poller type + var lro lroPoller + switch pt := tt[sem+1:]; pt { + case "async": + azcore.Log().Write(azcore.LogLongRunningOperation, "Resuming Azure-AsyncOperation poller.") + lro = &async.Poller{} + case "loc": + azcore.Log().Write(azcore.LogLongRunningOperation, "Resuming Location poller.") + lro = &loc.Poller{} + case "body": + azcore.Log().Write(azcore.LogLongRunningOperation, "Resuming Body poller.") + lro = &body.Poller{} default: - return nil, fmt.Errorf("unsupported method '%s'", method) + return nil, fmt.Errorf("unhandled poller type %s", pt) } - // now unmarshal into the tracker - err = json.Unmarshal([]byte(token), &pt) - if err != nil { + if err = json.Unmarshal([]byte(token), lro); err != nil { return nil, err } - return &poller{pt: pt}, nil + return &LROPoller{lro: lro, pl: pl, eu: eu}, nil } -// Poller defines the methods that will be called internally in the generated code for long-running operations. +// LROPoller encapsulates state and logic for polling on long-running operations. // NOTE: this is only meant for internal use in generated code. -type Poller interface { - // Done signals if the polling operation has reached a terminal state. - Done() bool - // Poll sends a polling request to the service endpoint and returns the http.Response received from the endpoint or an error. - Poll(ctx context.Context, p azcore.Pipeline) (*http.Response, error) - // FinalResponse will perform a final GET and return the final http response for the polling operation and unmarshal the content of the payload into the respType interface that is provided. - FinalResponse(ctx context.Context, pipeline azcore.Pipeline, respType interface{}) (*http.Response, error) - // ResumeToken returns a token string that can be used to resume polling on a poller that has not yet reached a terminal state. - ResumeToken() (string, error) - // PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, then return the final http response for the polling operation and unmarshal the content of the payload into the respType interface that is provided. - PollUntilDone(ctx context.Context, frequency time.Duration, pipeline azcore.Pipeline, respType interface{}) (*http.Response, error) -} - -type poller struct { - pt pollingTracker +type LROPoller struct { + lro lroPoller + pl azcore.Pipeline + eu ErrorUnmarshaller + resp *azcore.Response + err error } -func (p *poller) Done() bool { - return p.pt.hasTerminated() +// Done returns true if the LRO has reached a terminal state. +func (l *LROPoller) Done() bool { + if l.err != nil { + return true + } + return l.lro.Done() } -func (p *poller) FinalResponse(ctx context.Context, pipeline azcore.Pipeline, respType interface{}) (*http.Response, error) { - if !p.pt.hasTerminated() { - return nil, errors.New("cannot return a final response from a poller in a non-terminal state") - } - // if respType is nil, this indicates that the request was made from an HTTPPoller - if respType == nil { - return p.pt.latestResponse().Response, nil - } - if p.pt.pollerMethodVerb() == http.MethodPut || p.pt.pollerMethodVerb() == http.MethodPatch { - res, err := p.handleResponse(p.pt.latestResponse(), respType) - if err != nil { - return nil, err - } - if res != nil && !reflect.Indirect(reflect.ValueOf(respType)).IsZero() { - return res, nil +// Poll sends a polling request to the polling endpoint and returns the response or error. +func (l *LROPoller) Poll(ctx context.Context) (*http.Response, error) { + if l.Done() { + // the LRO has reached a terminal state, don't poll again + if l.resp != nil { + return l.resp.Response, nil } + return nil, l.err } - azcore.Log().Writef(azcore.LogLongRunningOperation, "performing final GET for %s", p.pt.pollerType()) - // checking if there was a FinalStateVia configuration to re-route the final GET - // request to the value specified in the FinalStateVia property on the poller - err := p.pt.setFinalState() + req, err := azcore.NewRequest(ctx, http.MethodGet, l.lro.URL()) if err != nil { return nil, err } - if p.pt.finalGetURL() == "" { - // we can end up in this situation if the async operation returns a 200 - // with no polling URLs. in that case return the response which should - // contain the JSON payload (only do this for successful terminal cases). - if lr := p.pt.latestResponse(); lr != nil && p.pt.hasSucceeded() { - result, err := p.handleResponse(lr, respType) - if err != nil { - return nil, err - } - return result, nil - } - return nil, errors.New("missing URL for retrieving result") - } - req, err := azcore.NewRequest(ctx, http.MethodGet, p.pt.finalGetURL()) + resp, err := l.pl.Do(req) if err != nil { + // don't update the poller for failed requests return nil, err } - resp, err := pipeline.Do(req) - if err != nil { + defer resp.Body.Close() + if !lroStatusCodeValid(resp) { + // the LRO failed. unmarshall the error and update state + l.err = l.eu(resp) + l.resp = nil + return nil, l.err + } + if err = l.lro.Update(resp); err != nil { return nil, err } - return p.handleResponse(resp, respType) + l.resp = resp + return l.resp.Response, nil } -func (p *poller) ResumeToken() (string, error) { - if p.pt.hasTerminated() { +// ResumeToken returns a token string that can be used to resume a poller that has not yet reached a terminal state. +func (l *LROPoller) ResumeToken() (string, error) { + if l.Done() { return "", errors.New("cannot create a ResumeToken from a poller in a terminal state") } - js, err := json.Marshal(p.pt) + b, err := json.Marshal(l.lro) if err != nil { return "", err } - return string(js), nil + return string(b), nil } -func (p *poller) handleResponse(resp *azcore.Response, respType interface{}) (*http.Response, error) { - if resp.HasStatusCode(http.StatusNoContent) { - return resp.Response, nil +// FinalResponse will perform a final GET request and return the final HTTP response for the polling +// operation and unmarshall the content of the payload into the respType interface that is provided. +func (l *LROPoller) FinalResponse(ctx context.Context, respType interface{}) (*http.Response, error) { + if !l.Done() { + return nil, errors.New("cannot return a final response from a poller in a non-terminal state") } - if !resp.HasStatusCode(pollingCodes[:]...) { - return nil, p.pt.handleError(resp) + // if there's nothing to unmarshall into just return the final response + if respType == nil { + return l.resp.Response, nil } - return resp.Response, resp.UnmarshalAsJSON(respType) -} - -func (p *poller) Poll(ctx context.Context, pipeline azcore.Pipeline) (*http.Response, error) { - if p.pollForStatus(ctx, pipeline, p.pt) { - return p.pt.latestResponse().Response, p.pt.pollingError() + if u := l.lro.FinalGetURL(); u != "" { + azcore.Log().Write(azcore.LogLongRunningOperation, "Performing final GET.") + req, err := azcore.NewRequest(ctx, http.MethodGet, u) + if err != nil { + return nil, err + } + resp, err := l.pl.Do(req) + if err != nil { + return nil, err + } + if !lroStatusCodeValid(resp) { + return nil, l.eu(resp) + } + l.resp = resp + } + body, err := ioutil.ReadAll(l.resp.Body) + l.resp.Body.Close() + if err != nil { + return nil, err + } + if err = json.Unmarshal(body, respType); err != nil { + return nil, err } - return nil, p.pt.pollingError() + return l.resp.Response, nil } -func (p *poller) PollUntilDone(ctx context.Context, frequency time.Duration, pipeline azcore.Pipeline, respType interface{}) (*http.Response, error) { +// PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, +// then return the final HTTP response for the polling operation and unmarshal the content of the payload +// into the respType interface that is provided. +// freq - the time to wait between polling intervals if the endpoint doesn't send a Retry-After header. +// A good starting value is 30 seconds. Note that some resources might benefit from a different value. +func (l *LROPoller) PollUntilDone(ctx context.Context, freq time.Duration, respType interface{}) (*http.Response, error) { + start := time.Now() logPollUntilDoneExit := func(v interface{}) { - azcore.Log().Writef(azcore.LogLongRunningOperation, "END PollUntilDone() for %s: %v", p.pt.pollerType(), v) + azcore.Log().Writef(azcore.LogLongRunningOperation, "END PollUntilDone() for %T: %v, total time: %s", l.lro, v, time.Now().Sub(start)) } - azcore.Log().Writef(azcore.LogLongRunningOperation, "BEGIN PollUntilDone() for %s", p.pt.pollerType()) - // p.resp should only be nil when calling PollUntilDone from a poller that was instantiated from a resume token string - if resp := p.pt.latestResponse(); resp != nil { + azcore.Log().Writef(azcore.LogLongRunningOperation, "BEGIN PollUntilDone() for %T", l.lro) + if l.resp != nil { // initial check for a retry-after header existing on the initial response - if retryAfter := azcore.RetryAfter(resp.Response); retryAfter > 0 { + if retryAfter := azcore.RetryAfter(l.resp.Response); retryAfter > 0 { azcore.Log().Writef(azcore.LogLongRunningOperation, "initial Retry-After delay for %s", retryAfter.String()) - err := delay(ctx, retryAfter) - if err != nil { + if err := delay(ctx, retryAfter); err != nil { logPollUntilDoneExit(err) return nil, err } @@ -219,745 +222,84 @@ func (p *poller) PollUntilDone(ctx context.Context, frequency time.Duration, pip } // begin polling the endpoint until a terminal state is reached for { - resp, err := p.Poll(ctx, pipeline) + resp, err := l.Poll(ctx) if err != nil { logPollUntilDoneExit(err) return nil, err } - if p.pt.hasTerminated() { - break + if l.Done() { + status := l.lro.Status() + azcore.Log().Writef(azcore.LogLongRunningOperation, "Status %s", status) + logPollUntilDoneExit(status) + if !strings.EqualFold(status, "succeeded") { + return nil, l.eu(&azcore.Response{Response: resp}) + } + return l.FinalResponse(ctx, respType) } - d := frequency + d := freq if retryAfter := azcore.RetryAfter(resp); retryAfter > 0 { azcore.Log().Writef(azcore.LogLongRunningOperation, "Retry-After delay for %s", retryAfter.String()) d = retryAfter } else { azcore.Log().Writef(azcore.LogLongRunningOperation, "delay for %s", d.String()) } - err = delay(ctx, d) - if err != nil { + if err = delay(ctx, d); err != nil { logPollUntilDoneExit(err) return nil, err } } - logPollUntilDoneExit(p.pt.pollingStatus()) - return p.FinalResponse(ctx, pipeline, respType) -} - -func delay(ctx context.Context, delay time.Duration) error { - select { - case <-time.After(delay): - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -// pollForStatus queries the service to see if the operation has completed. -func (p *poller) pollForStatus(ctx context.Context, pl azcore.Pipeline, pt pollingTracker) bool { - if pt.hasTerminated() { - return true - } - if err := pt.pollForStatus(ctx, pl); err != nil { - return true - } - if err := pt.checkForErrors(); err != nil { - return true - } - if err := pt.updatePollingState(pt.provisioningStateApplicable()); err != nil { - return true - } - if err := pt.initPollingMethod(); err != nil { - return true - } - if err := pt.updatePollingMethod(); err != nil { - return true - } - // VERB status method URL - azcore.Log().Writef(azcore.LogLongRunningOperation, "%s %s %s %s", - strings.ToUpper(pt.pollerMethodVerb()), pt.pollingStatus(), pt.pollingMethod(), pt.pollingURL()) - return pt.hasTerminated() -} - -type pollingTracker interface { - // these methods can differ per tracker - - // checks the response headers and status code to determine the polling mechanism - updatePollingMethod() error - - // checks the response for tracker-specific error conditions - checkForErrors() error - - // returns true if provisioning state should be checked - provisioningStateApplicable() bool - - // methods common to all trackers - - // initializes a tracker's polling URL and method, called for each iteration. - // these values can be overridden by each polling tracker as required. - initPollingMethod() error - - // initializes the tracker's internal state, call this when the tracker is created - initializeState() error - - // makes an HTTP request to check the status of the LRO - pollForStatus(ctx context.Context, client azcore.Pipeline) error - - // updates internal tracker state, call this after each call to pollForStatus - updatePollingState(provStateApl bool) error - - // returns the error response from the service, can be nil - pollingError() error - - // returns the polling method being used - pollingMethod() pollingMethodType - - // returns the state of the LRO as returned from the service - pollingStatus() string - - // returns the URL used for polling status - pollingURL() string - - // returns the client poller type - pollerType() string - - // returns the URL used for the final GET to retrieve the resource - finalGetURL() string - - // returns true if the LRO is in a terminal state - hasTerminated() bool - - // returns true if the LRO is in a failed terminal state - hasFailed() bool - - // returns true if the LRO is in a successful terminal state - hasSucceeded() bool - - // returns the cached HTTP response after a call to pollForStatus(), can be nil - latestResponse() *azcore.Response - - // converts an *azcore.Response to an error - handleError(resp *azcore.Response) error - - // sets the FinalGetURI to the value pointed to in FinalStateVia - setFinalState() error - - // returns the verb used with the initial request - pollerMethodVerb() string } -type methodErrorHandler func(resp *azcore.Response) error - -type pollingTrackerBase struct { - // resp is the last response, either from the submission of the LRO or from polling - resp *azcore.Response - - // PollerType is the name of the type of poller that is created - PollerType string `json:"pollerType"` - - // errorHandler is the method to invoke to unmarshall an error response - errorHandler methodErrorHandler - - // method is the HTTP verb, this is needed for deserialization - Method string `json:"method"` - - // rawBody is the raw JSON response body - rawBody map[string]interface{} +var _ azcore.Poller = (*LROPoller)(nil) - // denotes if polling is using async-operation or location header - Pm pollingMethodType `json:"pollingMethod"` - - // the URL to poll for status - URI string `json:"pollingURI"` - - // the state of the LRO as returned from the service - State string `json:"lroState"` - - // the URL to GET for the final result - FinalGetURI string `json:"resultURI"` - - // stores the name of the header that the final get should be performed on, - // can be empty which will go to default behavior - FinalStateVia string `json:"finalStateVia"` - - // the original request URL of the initial request for the polling operation - OriginalURI string `json:"originalURI"` - - // used to hold an error object returned from the service - Err error `json:"error,omitempty"` -} - -func newPollingTrackerBase(pollerType, finalState string, resp *azcore.Response, errorHandler methodErrorHandler) pollingTrackerBase { - return pollingTrackerBase{ - PollerType: pollerType, - FinalStateVia: finalState, - OriginalURI: resp.Request.URL.String(), - resp: resp, - errorHandler: errorHandler, - } +// abstracts the differences between concrete poller types +type lroPoller interface { + Done() bool + Update(resp *azcore.Response) error + FinalGetURL() string + URL() string + Status() string } -func (pt *pollingTrackerBase) initializeState() error { - // determine the initial polling state based on response body and/or HTTP status - // code. this is applicable to the initial LRO response, not polling responses! - pt.Method = pt.resp.Request.Method - if err := pt.updateRawBody(); err != nil { - pt.Err = err - return err - } - switch pt.resp.StatusCode { - case http.StatusOK: - if ps := pt.getProvisioningState(); ps != nil { - pt.State = *ps - if pt.hasFailed() { - pt.updateErrorFromResponse() - return pt.pollingError() - } - } else { - pt.State = operationSucceeded - } - case http.StatusCreated: - if ps := pt.getProvisioningState(); ps != nil { - pt.State = *ps - } else { - pt.State = operationInProgress - } - case http.StatusAccepted: - pt.State = operationInProgress - case http.StatusNoContent: - pt.State = operationSucceeded - default: - pt.State = operationFailed - pt.updateErrorFromResponse() - return pt.pollingError() - } - return pt.initPollingMethod() -} +// ==================================================================================================== -func (pt *pollingTrackerBase) getProvisioningState() *string { - if pt.rawBody != nil && pt.rawBody["properties"] != nil { - p := pt.rawBody["properties"].(map[string]interface{}) - if ps := p["provisioningState"]; ps != nil { - s := ps.(string) - return &s - } - } - return nil -} +// used if the operation synchronously completed +type nopPoller struct{} -func (pt *pollingTrackerBase) updateRawBody() error { - pt.rawBody = map[string]interface{}{} - if pt.resp.ContentLength != 0 { - defer pt.resp.Body.Close() - b, err := ioutil.ReadAll(pt.resp.Body) - if err != nil { - pt.Err = err - return pt.Err - } - // seek back to the beginning of the body or reassign the information to the body - if seeker, ok := pt.resp.Body.(io.Seeker); ok { - _, err = seeker.Seek(0, io.SeekStart) - if err != nil { - pt.Err = err - return pt.Err - } - } else { - // put the body back so it's available to other callers - pt.resp.Body = ioutil.NopCloser(bytes.NewReader(b)) - } - // observed in 204 responses over HTTP/2.0; the content length is -1 but body is empty - if len(b) == 0 { - return nil - } - if err = json.Unmarshal(b, &pt.rawBody); err != nil { - pt.Err = err - return pt.Err - } - } - return nil +func (*nopPoller) URL() string { + return "" } -func (pt *pollingTrackerBase) pollForStatus(ctx context.Context, client azcore.Pipeline) error { - req, err := azcore.NewRequest(ctx, http.MethodGet, pt.URI) - if err != nil { - pt.Err = err - return err - } - resp, err := client.Do(req) - pt.resp = resp - if err != nil { - pt.Err = err - return pt.Err - } - if pt.resp.HasStatusCode(pollingCodes[:]...) { - // reset the service error on success case - pt.Err = nil - err = pt.updateRawBody() - } else { - // check response body for error content - pt.updateErrorFromResponse() - err = pt.pollingError() - } - return err +func (*nopPoller) Done() bool { + return true } -// attempts to unmarshal a ServiceError type from the response body. -// if that fails then make a best attempt at creating something meaningful. -// NOTE: this assumes that the async operation has failed. -func (pt *pollingTrackerBase) updateErrorFromResponse() { - pt.Err = pt.errorHandler(pt.resp) +func (*nopPoller) Succeeded() bool { + return true } -func (pt *pollingTrackerBase) updatePollingState(provStateApl bool) error { - if pt.Pm == pollingAsyncOperation && pt.rawBody["status"] != nil { - pt.State = pt.rawBody["status"].(string) - } else { - if pt.resp.StatusCode == http.StatusAccepted { - pt.State = operationInProgress - } else if provStateApl { - if ps := pt.getProvisioningState(); ps != nil { - pt.State = *ps - } else { - pt.State = operationSucceeded - } - } else { - pt.Err = fmt.Errorf("the response from the async operation has an invalid status code: %d", pt.resp.StatusCode) - return pt.Err - } - } - // if the operation has failed update the error state - if pt.hasFailed() { - pt.updateErrorFromResponse() - } +func (*nopPoller) Update(*azcore.Response) error { return nil } -func (pt *pollingTrackerBase) pollingError() error { - return pt.Err -} - -func (pt *pollingTrackerBase) pollingMethod() pollingMethodType { - return pt.Pm -} - -func (pt *pollingTrackerBase) pollingStatus() string { - return pt.State -} - -func (pt *pollingTrackerBase) pollingURL() string { - return pt.URI +func (*nopPoller) FinalGetURL() string { + return "" } -func (pt *pollingTrackerBase) pollerType() string { - return pt.PollerType +func (*nopPoller) Status() string { + return "Succeeded" } -func (pt *pollingTrackerBase) finalGetURL() string { - return pt.FinalGetURI -} - -func (pt *pollingTrackerBase) hasTerminated() bool { - return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) || strings.EqualFold(pt.State, operationSucceeded) -} - -func (pt *pollingTrackerBase) hasFailed() bool { - return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) -} - -func (pt *pollingTrackerBase) hasSucceeded() bool { - return strings.EqualFold(pt.State, operationSucceeded) -} - -func (pt *pollingTrackerBase) latestResponse() *azcore.Response { - return pt.resp -} - -// error checking common to all trackers -func (pt *pollingTrackerBase) baseCheckForErrors() error { - // for Azure-AsyncOperations the response body cannot be nil or empty - if pt.Pm == pollingAsyncOperation { - if pt.resp.Body == nil || pt.resp.ContentLength == 0 { - pt.Err = errors.New("for Azure-AsyncOperation response body cannot be nil") - return pt.Err - } - if pt.rawBody["status"] == nil { - pt.Err = errors.New("missing status property in Azure-AsyncOperation response body") - return pt.Err - } - } - return nil +// returns true if the LRO response contains a valid HTTP status code +func lroStatusCodeValid(resp *azcore.Response) bool { + return resp.HasStatusCode(http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) } -// default initialization of polling URL/method. each verb tracker will update this as required. -func (pt *pollingTrackerBase) initPollingMethod() error { - if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - return nil - } - if lh, err := getURLFromLocationHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if lh != "" { - pt.URI = lh - pt.Pm = pollingLocation - return nil - } - // it's ok if we didn't find a polling header, this will be handled elsewhere - return nil -} - -func (pt *pollingTrackerBase) handleError(resp *azcore.Response) error { - return pt.errorHandler(resp) -} - -func (pt *pollingTrackerBase) setFinalState() error { - if len(pt.FinalStateVia) == 0 { +func delay(ctx context.Context, delay time.Duration) error { + select { + case <-time.After(delay): return nil + case <-ctx.Done(): + return ctx.Err() } - if pt.FinalStateVia == "azure-async-operation" { - ao, err := getURLFromAsyncOpHeader(pt.latestResponse()) - if err != nil { - return err - } - if ao != "" { - pt.FinalGetURI = ao - } - } else if pt.FinalStateVia == "location" { - lh, err := getURLFromLocationHeader(pt.latestResponse()) - if err != nil { - return err - } - if lh != "" { - pt.FinalGetURI = lh - } - } else if pt.FinalStateVia == "original-uri" { - pt.FinalGetURI = pt.OriginalURI - } - return nil -} - -func (pt *pollingTrackerBase) pollerMethodVerb() string { - return pt.Method -} - -// DELETE - -type pollingTrackerDelete struct { - pollingTrackerBase -} - -func (pt *pollingTrackerDelete) updatePollingMethod() error { - // for 201 the Location header is required - if pt.resp.StatusCode == http.StatusCreated { - if lh, err := getURLFromLocationHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if lh == "" { - pt.Err = errors.New("missing Location header in 201 response") - return pt.Err - } else { - pt.URI = lh - } - pt.Pm = pollingLocation - pt.FinalGetURI = pt.URI - } - // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary - if pt.resp.StatusCode == http.StatusAccepted { - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - // if the Location header is invalid and we already have a polling URL - // then we don't care if the Location header URL is malformed. - if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { - pt.Err = err - return err - } else if lh != "" { - if ao == "" { - pt.URI = lh - pt.Pm = pollingLocation - } - // when both headers are returned we use the value in the Location header for the final GET - pt.FinalGetURI = lh - } - // make sure a polling URL was found - if pt.URI == "" { - pt.Err = errors.New("didn't get any suitable polling URLs in 202 response") - return pt.Err - } - } - return nil -} - -func (pt *pollingTrackerDelete) checkForErrors() error { - return pt.baseCheckForErrors() -} - -func (pt *pollingTrackerDelete) provisioningStateApplicable() bool { - return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent -} - -// PATCH - -type pollingTrackerPatch struct { - pollingTrackerBase -} - -func (pt *pollingTrackerPatch) updatePollingMethod() error { - // by default we can use the original URL for polling and final GET - if pt.URI == "" { - pt.URI = pt.resp.Request.URL.String() - } - if pt.FinalGetURI == "" { - pt.FinalGetURI = pt.resp.Request.URL.String() - } - if pt.Pm == pollingUnknown { - pt.Pm = pollingRequestURI - } - // for 201 it's permissible for no headers to be returned - if pt.resp.StatusCode == http.StatusCreated { - if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - } - // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary - // note the absence of the "final GET" mechanism for PATCH - if pt.resp.StatusCode == http.StatusAccepted { - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - if ao == "" { - if lh, err := getURLFromLocationHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if lh == "" { - pt.Err = errors.New("didn't get any suitable polling URLs in 202 response") - return pt.Err - } else { - pt.URI = lh - pt.Pm = pollingLocation - } - } - } - return nil -} - -func (pt *pollingTrackerPatch) checkForErrors() error { - return pt.baseCheckForErrors() -} - -func (pt *pollingTrackerPatch) provisioningStateApplicable() bool { - return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated -} - -// POST - -type pollingTrackerPost struct { - pollingTrackerBase -} - -func (pt *pollingTrackerPost) updatePollingMethod() error { - // 201 requires Location header - if pt.resp.StatusCode == http.StatusCreated { - if lh, err := getURLFromLocationHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if lh == "" { - pt.Err = errors.New("missing Location header in 201 response") - return pt.Err - } else { - pt.URI = lh - pt.FinalGetURI = lh - pt.Pm = pollingLocation - } - } - // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary - if pt.resp.StatusCode == http.StatusAccepted { - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - // if the Location header is invalid and we already have a polling URL - // then we don't care if the Location header URL is malformed. - if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { - pt.Err = err - return err - } else if lh != "" { - if ao == "" { - pt.URI = lh - pt.Pm = pollingLocation - } - // when both headers are returned we use the value in the Location header for the final GET - pt.FinalGetURI = lh - } - // make sure a polling URL was found - if pt.URI == "" { - pt.Err = errors.New("didn't get any suitable polling URLs in 202 response") - return pt.Err - } - } - return nil -} - -func (pt *pollingTrackerPost) checkForErrors() error { - return pt.baseCheckForErrors() -} - -func (pt *pollingTrackerPost) provisioningStateApplicable() bool { - return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent -} - -// PUT - -type pollingTrackerPut struct { - pollingTrackerBase -} - -func (pt *pollingTrackerPut) updatePollingMethod() error { - // by default we can use the original URL for polling and final GET - if pt.URI == "" { - pt.URI = pt.resp.Request.URL.String() - } - if pt.FinalGetURI == "" { - pt.FinalGetURI = pt.resp.Request.URL.String() - } - if pt.Pm == pollingUnknown { - pt.Pm = pollingRequestURI - } - // for 201 it's permissible for no headers to be returned - if pt.resp.StatusCode == http.StatusCreated { - if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - } - // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary - if pt.resp.StatusCode == http.StatusAccepted { - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - // if the Location header is invalid and we already have a polling URL - // then we don't care if the Location header URL is malformed. - if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { - pt.Err = err - return err - } else if lh != "" { - if ao == "" { - pt.URI = lh - pt.Pm = pollingLocation - } - } - // make sure a polling URL was found - if pt.URI == "" { - pt.Err = errors.New("didn't get any suitable polling URLs in 202 response") - return pt.Err - } - } - return nil -} - -func (pt *pollingTrackerPut) checkForErrors() error { - err := pt.baseCheckForErrors() - if err != nil { - pt.Err = err - return err - } - // if there are no LRO headers then the body cannot be empty - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } - lh, err := getURLFromLocationHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } - if ao == "" && lh == "" && len(pt.rawBody) == 0 { - pt.Err = errors.New("the response did not contain a body") - return pt.Err - } - return nil -} - -func (pt *pollingTrackerPut) provisioningStateApplicable() bool { - return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated -} - -// gets the polling URL from the Azure-AsyncOperation header. -// ensures the URL is well-formed and absolute. -func getURLFromAsyncOpHeader(resp *azcore.Response) (string, error) { - s := resp.Header.Get(headerAsyncOperation) - if s == "" { - return "", nil - } - if !isValidURL(s) { - return "", fmt.Errorf("invalid polling URL '%s'", s) - } - return s, nil -} - -// gets the polling URL from the Location header. -// ensures the URL is well-formed and absolute. -func getURLFromLocationHeader(resp *azcore.Response) (string, error) { - s := resp.Header.Get(headerLocation) - if s == "" { - return "", nil - } - if !isValidURL(s) { - return "", fmt.Errorf("invalid polling URL '%s'", s) - } - return s, nil -} - -// verify that the URL is valid and absolute -func isValidURL(s string) bool { - u, err := url.Parse(s) - return err == nil && u.IsAbs() } - -// pollingMethodType defines a type used for enumerating polling mechanisms. -type pollingMethodType string - -const ( - // pollingAsyncOperation indicates the polling method uses the Azure-AsyncOperation header. - pollingAsyncOperation pollingMethodType = "AsyncOperation" - - // pollingLocation indicates the polling method uses the Location header. - pollingLocation pollingMethodType = "Location" - - // pollingRequestURI indicates the polling method uses the original request URI. - pollingRequestURI pollingMethodType = "RequestURI" - - // pollingUnknown indicates an unknown polling method and is the default value. - pollingUnknown pollingMethodType = "" -) diff --git a/sdk/armcore/poller_test.go b/sdk/armcore/poller_test.go index fc250c84ed69..f2439a1a6589 100644 --- a/sdk/armcore/poller_test.go +++ b/sdk/armcore/poller_test.go @@ -7,196 +7,242 @@ package armcore import ( "context" - "fmt" + "io" "net/http" + "strings" "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/loc" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) const ( - mockSuccessResp = `{"field": "success"}` + provStateStarted = `{ "properties": { "provisioningState": "Started" } }` + provStateUpdating = `{ "properties": { "provisioningState": "Updating" } }` + provStateSucceeded = `{ "properties": { "provisioningState": "Succeeded" }, "field": "value" }` + statusInProgress = `{ "status": "InProgress" }` + statusSucceeded = `{ "status": "Succeeded" }` + successResp = `{ "field": "value" }` + errorResp = `{ "error": "the operation failed" }` ) type mockType struct { Field *string `json:"field,omitempty"` } +type mockError struct { + Msg string `json:"error"` +} + +func (m mockError) Error() string { + return m.Msg +} + func getPipeline(srv *mock.Server) azcore.Pipeline { return azcore.NewPipeline( srv, - azcore.NewRetryPolicy(nil), azcore.NewLogPolicy(nil)) } func handleError(resp *azcore.Response) error { - return fmt.Errorf("error status: %d", resp.StatusCode) + var me mockError + if err := resp.UnmarshalAsJSON(&me); err != nil { + return err + } + return me } -func TestNewPollerTracker(t *testing.T) { +func initialResponse(method, u string, resp io.Reader) *azcore.Response { + req, err := http.NewRequest(method, u, nil) + if err != nil { + panic(err) + } + return &azcore.Response{ + Response: &http.Response{ + Body: io.NopCloser(resp), + Header: http.Header{}, + Request: req, + }, + } +} + +func TestNewLROPollerAsync(t *testing.T) { srv, close := mock.NewServer() defer close() - srv.AppendResponse(mock.WithBody([]byte(mockSuccessResp))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPost, srv.URL()) + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - resp, err := p.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*async.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - pller, err := NewPoller("testPoller", "", resp, handleError) + tk, err := poller.ResumeToken() if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - pt := pller.(*poller).pt.(*pollingTrackerPost) - if pt.PollerType != "testPoller" { - t.Fatal("wrong poller type assigned") + t.Fatal(err) } - if pt.resp != resp { - t.Fatal("wrong response assigned") + poller, err = NewLROPollerFromResumeToken("pollerID", tk, pl, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) } - if pt.Method != "POST" { - t.Fatal("wrong poller method") + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } } -func TestResumeTokenFail(t *testing.T) { +func TestNewLROPollerBody(t *testing.T) { srv, close := mock.NewServer() defer close() - srv.AppendResponse(mock.WithBody([]byte(`{"field": "success"}`))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPost, srv.URL()) + srv.AppendResponse(mock.WithBody([]byte(provStateUpdating)), mock.WithHeader("Retry-After", "1")) + srv.AppendResponse(mock.WithBody([]byte(provStateSucceeded))) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - resp, err := p.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*body.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - poller, err := NewPoller("testPoller", "", resp, handleError) + tk, err := poller.ResumeToken() if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - tk, err := poller.ResumeToken() - if err == nil { - t.Fatalf("Expected an error but did not receive one") + poller, err = NewLROPollerFromResumeToken("pollerID", tk, pl, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) } - if tk != "" { - t.Fatal("did not expect to receive resume token for a poller in a terminal state") + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } } -func TestPollUntilDone(t *testing.T) { +func TestNewLROPollerLoc(t *testing.T) { srv, close := mock.NewServer() defer close() srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithBody([]byte(mockSuccessResp))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPut, srv.URL()) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderLocation, srv.URL()) + resp.StatusCode = http.StatusAccepted + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - resp, err := p.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*loc.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - poller, err := NewPoller("testPoller", "", resp, handleError) + tk, err := poller.ResumeToken() if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - m := &mockType{} - pollerResp, err := poller.PollUntilDone(context.Background(), 1*time.Millisecond, p, m) + poller, err = NewLROPollerFromResumeToken("pollerID", tk, pl, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - if pollerResp == nil { - t.Fatal("Unexpected nil response") - } - if pollerResp.StatusCode != http.StatusOK { - t.Fatal("Unexpected response status code") - } - if *m.Field != "success" { - t.Fatalf("Unexpected value for MockType.Field: %s", *m.Field) + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } } -func TestPutFinalResponseCheck(t *testing.T) { +func TestNewLROPollerNop(t *testing.T) { srv, close := mock.NewServer() defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithBody([]byte(`{"other": "other"}`))) - srv.AppendResponse(mock.WithBody([]byte(mockSuccessResp))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPut, srv.URL()) + resp := initialResponse(http.MethodPost, srv.URL(), strings.NewReader(successResp)) + resp.StatusCode = http.StatusOK + poller, err := NewLROPoller("pollerID", "", resp, getPipeline(srv), handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - resp, err := p.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - poller, err := NewPoller("testPoller", "", resp, handleError) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*nopPoller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - m := &mockType{} - pollerResp, err := poller.PollUntilDone(context.Background(), 1*time.Millisecond, p, m) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + tk, err := poller.ResumeToken() + if err == nil { + t.Fatal("unexpected nil error") } - if pollerResp == nil { - t.Fatal("Unexpected nil response") + if tk != "" { + t.Fatal("expected empty token") } - if pollerResp.StatusCode != http.StatusOK { - t.Fatal("Unexpected response status code") + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) } - if *m.Field != "success" { - t.Fatalf("Unexpected value for MockType.Field: %s", *m.Field) + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } } -func TestNewPollerFromResumeTokenTracker(t *testing.T) { +func TestNewLROPollerInitialRetryAfter(t *testing.T) { srv, close := mock.NewServer() defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithBody([]byte(`{"field": "success"}`))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPut, srv.URL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - resp, err := p.Do(req) + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderAzureAsync, srv.URL()) + resp.Header.Set("Retry-After", "1") + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - poller, err := NewPoller("testPoller", "", resp, handleError) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*async.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - tk, err := poller.ResumeToken() + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - poller, err = NewPollerFromResumeToken("testPoller", tk, handleError) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } - m := &mockType{} - pollerResp, err := poller.PollUntilDone(context.Background(), 1*time.Millisecond, p, m) +} + +func TestNewLROPollerFailed(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(errorResp)), mock.WithStatusCode(http.StatusBadRequest)) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - if pollerResp == nil { - t.Fatal("Unexpected nil response") + if _, ok := poller.lro.(*async.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - if pollerResp.StatusCode != http.StatusOK { - t.Fatal("Unexpected response status code") + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err == nil { + t.Fatal(err) } - if *m.Field != "success" { - t.Fatalf("Unexpected value for MockType.Field: %s", *m.Field) + if _, ok := err.(mockError); !ok { + t.Fatalf("unexpected error type %T", err) } } diff --git a/sdk/armcore/version.go b/sdk/armcore/version.go index c100fa154219..acdd3fcf5579 100644 --- a/sdk/armcore/version.go +++ b/sdk/armcore/version.go @@ -10,5 +10,5 @@ const ( UserAgent = "armcore/" + Version // Version is the semantic version (see http://semver.org) of this module. - Version = "v0.7.1" + Version = "v0.8.0" ) From a42cca6280e9d58a034e443ea2fa7776f9cdb724 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Mon, 7 Jun 2021 15:45:22 -0700 Subject: [PATCH 02/21] Fix NopCloser package for earlier versions of Go --- sdk/armcore/internal/pollers/async/async_test.go | 5 +++-- sdk/armcore/internal/pollers/body/body_test.go | 5 +++-- sdk/armcore/internal/pollers/pollers_test.go | 6 +++--- sdk/armcore/poller_test.go | 3 ++- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sdk/armcore/internal/pollers/async/async_test.go b/sdk/armcore/internal/pollers/async/async_test.go index 30b88a4e2f5d..c9f252d1cb15 100644 --- a/sdk/armcore/internal/pollers/async/async_test.go +++ b/sdk/armcore/internal/pollers/async/async_test.go @@ -7,6 +7,7 @@ package async import ( "io" + "io/ioutil" "net/http" "strings" "testing" @@ -27,7 +28,7 @@ func initialResponse(method string, resp io.Reader) *azcore.Response { } return &azcore.Response{ Response: &http.Response{ - Body: io.NopCloser(resp), + Body: ioutil.NopCloser(resp), Header: http.Header{}, Request: req, }, @@ -37,7 +38,7 @@ func initialResponse(method string, resp io.Reader) *azcore.Response { func pollingResponse(resp io.Reader) *azcore.Response { return &azcore.Response{ Response: &http.Response{ - Body: io.NopCloser(resp), + Body: ioutil.NopCloser(resp), Header: http.Header{}, }, } diff --git a/sdk/armcore/internal/pollers/body/body_test.go b/sdk/armcore/internal/pollers/body/body_test.go index d67a74f02a0f..cc468bb91026 100644 --- a/sdk/armcore/internal/pollers/body/body_test.go +++ b/sdk/armcore/internal/pollers/body/body_test.go @@ -7,6 +7,7 @@ package body import ( "io" + "io/ioutil" "net/http" "strings" "testing" @@ -25,7 +26,7 @@ func initialResponse(method string, resp io.Reader) *azcore.Response { } return &azcore.Response{ Response: &http.Response{ - Body: io.NopCloser(resp), + Body: ioutil.NopCloser(resp), Header: http.Header{}, Request: req, }, @@ -35,7 +36,7 @@ func initialResponse(method string, resp io.Reader) *azcore.Response { func pollingResponse(resp io.Reader) *azcore.Response { return &azcore.Response{ Response: &http.Response{ - Body: io.NopCloser(resp), + Body: ioutil.NopCloser(resp), Header: http.Header{}, }, } diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go index eb31235a84eb..80d06df2e7f6 100644 --- a/sdk/armcore/internal/pollers/pollers_test.go +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -7,7 +7,7 @@ package pollers import ( "errors" - "io" + "io/ioutil" "net/http" "strings" "testing" @@ -34,7 +34,7 @@ func TestGetStatusSuccess(t *testing.T) { const jsonBody = `{ "status": "InProgress" }` resp := azcore.Response{ Response: &http.Response{ - Body: io.NopCloser(strings.NewReader(jsonBody)), + Body: ioutil.NopCloser(strings.NewReader(jsonBody)), }, } status, err := GetStatus(&resp) @@ -65,7 +65,7 @@ func TestGetProvisioningState(t *testing.T) { const jsonBody = `{ "properties": { "provisioningState": "Canceled" } }` resp := azcore.Response{ Response: &http.Response{ - Body: io.NopCloser(strings.NewReader(jsonBody)), + Body: ioutil.NopCloser(strings.NewReader(jsonBody)), }, } state, err := GetProvisioningState(&resp) diff --git a/sdk/armcore/poller_test.go b/sdk/armcore/poller_test.go index f2439a1a6589..e08707354790 100644 --- a/sdk/armcore/poller_test.go +++ b/sdk/armcore/poller_test.go @@ -8,6 +8,7 @@ package armcore import ( "context" "io" + "io/ioutil" "net/http" "strings" "testing" @@ -64,7 +65,7 @@ func initialResponse(method, u string, resp io.Reader) *azcore.Response { } return &azcore.Response{ Response: &http.Response{ - Body: io.NopCloser(resp), + Body: ioutil.NopCloser(resp), Header: http.Header{}, Request: req, }, From 5cca901cc749303d317b61ab4b96992bfce1b5bd Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 8 Jun 2021 09:15:01 -0700 Subject: [PATCH 03/21] handle empty response when a model is provided log a message in this case --- sdk/armcore/poller.go | 5 ++++- sdk/armcore/poller_test.go | 37 ++++++++++++++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index 73c75692b07c..b4ad987cb910 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -169,9 +169,12 @@ func (l *LROPoller) FinalResponse(ctx context.Context, respType interface{}) (*h if !l.Done() { return nil, errors.New("cannot return a final response from a poller in a non-terminal state") } - // if there's nothing to unmarshall into just return the final response + // if there's nothing to unmarshall into or no response body just return the final response if respType == nil { return l.resp.Response, nil + } else if l.resp.StatusCode == http.StatusNoContent || l.resp.ContentLength == 0 { + azcore.Log().Write(azcore.LogLongRunningOperation, "final response specifies a response type but no payload was received") + return l.resp.Response, nil } if u := l.lro.FinalGetURL(); u != "" { azcore.Log().Write(azcore.LogLongRunningOperation, "Performing final GET.") diff --git a/sdk/armcore/poller_test.go b/sdk/armcore/poller_test.go index e08707354790..589458f229fc 100644 --- a/sdk/armcore/poller_test.go +++ b/sdk/armcore/poller_test.go @@ -65,9 +65,10 @@ func initialResponse(method, u string, resp io.Reader) *azcore.Response { } return &azcore.Response{ Response: &http.Response{ - Body: ioutil.NopCloser(resp), - Header: http.Header{}, - Request: req, + Body: ioutil.NopCloser(resp), + ContentLength: -1, + Header: http.Header{}, + Request: req, }, } } @@ -247,3 +248,33 @@ func TestNewLROPollerFailed(t *testing.T) { t.Fatalf("unexpected error type %T", err) } } + +func TestNewLROPollerSuccessNoContent(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(provStateUpdating))) + srv.AppendResponse(mock.WithStatusCode(http.StatusNoContent)) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if _, ok := poller.lro.(*body.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewLROPollerFromResumeToken("pollerID", tk, pl, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if result.Field != nil { + t.Fatal("expected nil result") + } +} From 1de83a6f8519425d55e13c6db90c53dac5875878 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 8 Jun 2021 14:30:58 -0700 Subject: [PATCH 04/21] update location polling URL --- sdk/armcore/internal/pollers/loc/loc.go | 4 ++++ sdk/armcore/internal/pollers/loc/loc_test.go | 17 +++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go index 1b941dd08254..c01b5b33f425 100644 --- a/sdk/armcore/internal/pollers/loc/loc.go +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -47,6 +47,10 @@ func (p *Poller) Done() bool { // Update updates the Poller from the polling response. func (p *Poller) Update(resp *azcore.Response) error { + // location polling can return an updated polling URL + if h := resp.Header.Get(pollers.HeaderLocation); h != "" { + p.PollURL = h + } // any 2xx code other than 202 indicates success if resp.HasStatusCode(http.StatusOK, http.StatusCreated, http.StatusNoContent) { p.CurState = "Succeeded" diff --git a/sdk/armcore/internal/pollers/loc/loc_test.go b/sdk/armcore/internal/pollers/loc/loc_test.go index bc205c6bb487..aa5986489c06 100644 --- a/sdk/armcore/internal/pollers/loc/loc_test.go +++ b/sdk/armcore/internal/pollers/loc/loc_test.go @@ -14,7 +14,8 @@ import ( ) const ( - fakePollingURL = "https://foo.bar.baz/status" + fakePollingURL1 = "https://foo.bar.baz/status" + fakePollingURL2 = "https://foo.bar.baz/updated" fakeResourceURL = "https://foo.bar.baz/resource" ) @@ -46,7 +47,7 @@ func TestApplicable(t *testing.T) { if Applicable(&resp) { t.Fatal("missing Location should not be applicable") } - resp.Response.Header.Set(pollers.HeaderLocation, fakePollingURL) + resp.Response.Header.Set(pollers.HeaderLocation, fakePollingURL1) if !Applicable(&resp) { t.Fatal("having Location should be applicable") } @@ -54,7 +55,7 @@ func TestApplicable(t *testing.T) { func TestNew(t *testing.T) { resp := initialResponse(http.MethodPut) - resp.Header.Set(pollers.HeaderLocation, fakePollingURL) + resp.Header.Set(pollers.HeaderLocation, fakePollingURL1) poller, err := New(resp, "pollerID") if err != nil { t.Fatal(err) @@ -68,7 +69,15 @@ func TestNew(t *testing.T) { if s := poller.Status(); s != "InProgress" { t.Fatalf("unexpected status %s", s) } - if u := poller.URL(); u != fakePollingURL { + if u := poller.URL(); u != fakePollingURL1 { + t.Fatalf("unexpected polling URL %s", u) + } + pr := pollingResponse(http.StatusAccepted) + pr.Header.Set(pollers.HeaderLocation, fakePollingURL2) + if err := poller.Update(pr); err != nil { + t.Fatal(err) + } + if u := poller.URL(); u != fakePollingURL2 { t.Fatalf("unexpected polling URL %s", u) } if err := poller.Update(pollingResponse(http.StatusNoContent)); err != nil { From f0055bdfa16e5db548972c58f892e9e4a5d875a9 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 8 Jun 2021 14:54:51 -0700 Subject: [PATCH 05/21] fix final GET for POST --- sdk/armcore/internal/pollers/async/async.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go index 1d9aaacdb111..acfe81b1ac7a 100644 --- a/sdk/armcore/internal/pollers/async/async.go +++ b/sdk/armcore/internal/pollers/async/async.go @@ -87,13 +87,15 @@ func (p *Poller) FinalGetURL() string { // for PATCH and PUT, the final GET is on the original resource URL return p.OrigURL } else if p.Method == http.MethodPost { - // for POST, we need to consult the final-state-via flag - if p.FinalState == finalStateLoc && p.LocURL != "" { - return p.LocURL + if p.FinalState == finalStateAsync { + return "" } else if p.FinalState == finalStateOrig { return p.OrigURL + } else if p.LocURL != "" { + // ideally FinalState would be set to "location" but it isn't always. + // must check last due to more permissive condition. + return p.LocURL } - // finalStateAsync fall through } return "" } From 72a12ea2a99dc709c96ae6601461dca3f316f2fd Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 8 Jun 2021 15:38:33 -0700 Subject: [PATCH 06/21] handle absense of provisioning state for initial response body polling --- sdk/armcore/internal/pollers/body/body.go | 23 +++++++++-- .../internal/pollers/body/body_test.go | 38 ++++++++++++++++--- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sdk/armcore/internal/pollers/body/body.go b/sdk/armcore/internal/pollers/body/body.go index 331ac63d4aa1..dde56c1424bf 100644 --- a/sdk/armcore/internal/pollers/body/body.go +++ b/sdk/armcore/internal/pollers/body/body.go @@ -35,12 +35,27 @@ func New(resp *azcore.Response, pollerID string) (*Poller, error) { Type: pollers.MakeID(pollerID, "body"), PollURL: resp.Request.URL.String(), } - // the initial response must contain a provisioning state - state, err := pollers.GetProvisioningState(resp) - if err != nil { + // default initial state to InProgress. depending on the HTTP + // status code and provisioning state, we might change the value. + curState := "InProgress" + provState, err := pollers.GetProvisioningState(resp) + if err != nil && !errors.Is(err, pollers.ErrNoProvisioningState) { return nil, err } - p.CurState = state + if resp.StatusCode == http.StatusCreated && provState != "" { + // absense of provisioning state is ok for a 201, means the operation is in progress + curState = provState + } else if resp.StatusCode == http.StatusOK { + if provState != "" { + curState = provState + } else if provState == "" { + // for a 200, absense of provisioning state indicates success + curState = "Succeeded" + } + } else if resp.StatusCode == http.StatusNoContent { + curState = "Succeeded" + } + p.CurState = curState return p, nil } diff --git a/sdk/armcore/internal/pollers/body/body_test.go b/sdk/armcore/internal/pollers/body/body_test.go index cc468bb91026..a4a08ef04897 100644 --- a/sdk/armcore/internal/pollers/body/body_test.go +++ b/sdk/armcore/internal/pollers/body/body_test.go @@ -67,6 +67,7 @@ func TestApplicable(t *testing.T) { func TestNew(t *testing.T) { const jsonBody = `{ "properties": { "provisioningState": "Started" } }` resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusCreated poller, err := New(resp, "pollerID") if err != nil { t.Fatal(err) @@ -94,6 +95,7 @@ func TestNew(t *testing.T) { func TestUpdateNoProvState(t *testing.T) { const jsonBody = `{ "properties": { "provisioningState": "Started" } }` resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK poller, err := New(resp, "pollerID") if err != nil { t.Fatal(err) @@ -118,14 +120,38 @@ func TestUpdateNoProvState(t *testing.T) { } } -func TestNewFail(t *testing.T) { - // missing provisioning state on initial response +func TestNewNoInitialProvStateOK(t *testing.T) { resp := initialResponse(http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusOK poller, err := New(resp, "pollerID") - if err == nil { - t.Fatal("unexpected nil error") + if err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("poller not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewNoInitialProvStateNC(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusNoContent + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("poller not be done") } - if poller != nil { - t.Fatal("expected nil poller") + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) } } From c23a4999c6cdc1ad2857cf6c40ddfc1a030e9880 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 8 Jun 2021 15:45:57 -0700 Subject: [PATCH 07/21] verify polling URLs --- sdk/armcore/internal/pollers/async/async.go | 4 ++++ sdk/armcore/internal/pollers/loc/loc.go | 7 ++++++- sdk/armcore/internal/pollers/pollers.go | 7 +++++++ sdk/armcore/internal/pollers/pollers_test.go | 9 +++++++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go index acfe81b1ac7a..5e25eef930ef 100644 --- a/sdk/armcore/internal/pollers/async/async.go +++ b/sdk/armcore/internal/pollers/async/async.go @@ -7,6 +7,7 @@ package async import ( "errors" + "fmt" "net/http" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" @@ -42,6 +43,9 @@ func New(resp *azcore.Response, finalState string, pollerID string) (*Poller, er if asyncURL == "" { return nil, errors.New("response is missing Azure-AsyncOperation header") } + if !pollers.IsValidURL(asyncURL) { + return nil, fmt.Errorf("invalid polling URL %s", asyncURL) + } p := &Poller{ Type: pollers.MakeID(pollerID, "async"), AsyncURL: asyncURL, diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go index c01b5b33f425..853410bd66ea 100644 --- a/sdk/armcore/internal/pollers/loc/loc.go +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -6,6 +6,7 @@ package loc import ( + "fmt" "net/http" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" @@ -27,9 +28,13 @@ type Poller struct { // New creates a new Poller from the provided initial response. func New(resp *azcore.Response, pollerID string) (*Poller, error) { azcore.Log().Write(azcore.LogLongRunningOperation, "Using Location poller.") + locURL := resp.Header.Get(pollers.HeaderLocation) + if !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid polling URL %s", locURL) + } p := &Poller{ Type: pollers.MakeID(pollerID, "loc"), - PollURL: resp.Header.Get(pollers.HeaderLocation), + PollURL: locURL, CurState: "InProgress", } return p, nil diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index 0ee2c0c0e966..0f0081c80012 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io/ioutil" + "net/url" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -108,6 +109,12 @@ func GetProvisioningState(resp *azcore.Response) (string, error) { return "", ErrNoProvisioningState } +// IsValidURL verifies that the URL is valid and absolute. +func IsValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() +} + // MakeID returns the unique poller identifier in the format pollerID;poller. func MakeID(pollerID string, kind string) string { return fmt.Sprintf("%s;%s", pollerID, kind) diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go index 80d06df2e7f6..6a502df99ba9 100644 --- a/sdk/armcore/internal/pollers/pollers_test.go +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -109,3 +109,12 @@ func TestMakeID(t *testing.T) { t.Fatalf("unexpected poller kind %s", p) } } + +func TestIsValidURL(t *testing.T) { + if IsValidURL("/foo") { + t.Fatal("unexpected valid URL") + } + if !IsValidURL("https://foo.bar/baz") { + t.Fatal("expected valid URL") + } +} From d62dc84ee3fb1d675b8f2604d8efc4d52f9aa99e Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 8 Jun 2021 16:58:01 -0700 Subject: [PATCH 08/21] move checking of errors to Poll method --- sdk/armcore/internal/pollers/pollers.go | 13 ++++++++- sdk/armcore/internal/pollers/pollers_test.go | 9 ++++++ sdk/armcore/poller.go | 9 ++++-- sdk/armcore/poller_test.go | 30 +++++++++++++++++++- 4 files changed, 56 insertions(+), 5 deletions(-) diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index 0f0081c80012..10d1cf505472 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -22,6 +22,12 @@ const ( HeaderLocation = "Location" ) +const ( + statusSucceeded = "succeeded" + statusCanceled = "canceled" + statusFailed = "failed" +) + // reads the response body into a raw JSON object. // returns an empty object if there was no content. func getJSON(resp *azcore.Response) (map[string]interface{}, error) { @@ -79,7 +85,12 @@ func status(jsonBody map[string]interface{}) string { // IsTerminalState returns true if the LRO's state is terminal. func IsTerminalState(s string) bool { - return strings.EqualFold(s, "succeeded") || strings.EqualFold(s, "failed") || strings.EqualFold(s, "canceled") + return strings.EqualFold(s, statusSucceeded) || strings.EqualFold(s, statusFailed) || strings.EqualFold(s, statusCanceled) +} + +// Failed returns true if the LRO's state is terminal failure. +func Failed(s string) bool { + return strings.EqualFold(s, statusFailed) || strings.EqualFold(s, statusCanceled) } // GetStatus returns the LRO's status from the response body. diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go index 6a502df99ba9..14b8bee1c471 100644 --- a/sdk/armcore/internal/pollers/pollers_test.go +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -118,3 +118,12 @@ func TestIsValidURL(t *testing.T) { t.Fatal("expected valid URL") } } + +func TestFailed(t *testing.T) { + if Failed("Succeeded") || Failed("Updating") { + t.Fatal("unexpected failure") + } + if !Failed("failed") { + t.Fatal("expected failure") + } +} diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index b4ad987cb910..1254e0d7dae0 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -15,6 +15,7 @@ import ( "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/async" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/body" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/loc" @@ -148,6 +149,11 @@ func (l *LROPoller) Poll(ctx context.Context) (*http.Response, error) { return nil, err } l.resp = resp + if pollers.Failed(l.lro.Status()) { + l.err = l.eu(resp) + l.resp = nil + return nil, l.err + } return l.resp.Response, nil } @@ -234,9 +240,6 @@ func (l *LROPoller) PollUntilDone(ctx context.Context, freq time.Duration, respT status := l.lro.Status() azcore.Log().Writef(azcore.LogLongRunningOperation, "Status %s", status) logPollUntilDoneExit(status) - if !strings.EqualFold(status, "succeeded") { - return nil, l.eu(&azcore.Response{Response: resp}) - } return l.FinalResponse(ctx, respType) } d := freq diff --git a/sdk/armcore/poller_test.go b/sdk/armcore/poller_test.go index 589458f229fc..cd836441e8ee 100644 --- a/sdk/armcore/poller_test.go +++ b/sdk/armcore/poller_test.go @@ -26,8 +26,10 @@ const ( provStateStarted = `{ "properties": { "provisioningState": "Started" } }` provStateUpdating = `{ "properties": { "provisioningState": "Updating" } }` provStateSucceeded = `{ "properties": { "provisioningState": "Succeeded" }, "field": "value" }` + provStateFailed = `{ "properties": { "provisioningState": "Failed" } }` statusInProgress = `{ "status": "InProgress" }` statusSucceeded = `{ "status": "Succeeded" }` + statusCanceled = `{ "status": "Canceled" }` successResp = `{ "field": "value" }` errorResp = `{ "error": "the operation failed" }` ) @@ -223,7 +225,33 @@ func TestNewLROPollerInitialRetryAfter(t *testing.T) { } } -func TestNewLROPollerFailed(t *testing.T) { +func TestNewLROPollerCanceled(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusCanceled)), mock.WithStatusCode(http.StatusOK)) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if _, ok := poller.lro.(*async.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) + } + _, err = poller.Poll(context.Background()) + if err != nil { + t.Fatal(err) + } + _, err = poller.Poll(context.Background()) + if err == nil { + t.Fatal("unexpected nil error") + } +} + +func TestNewLROPollerFailedWithError(t *testing.T) { srv, close := mock.NewServer() defer close() srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) From 111758a7e34e351f56d0d69e74164892ebc2d354 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 8 Jun 2021 17:01:42 -0700 Subject: [PATCH 09/21] move logging of status to Poll() --- sdk/armcore/poller.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index 1254e0d7dae0..0917c7fcf166 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -149,6 +149,7 @@ func (l *LROPoller) Poll(ctx context.Context) (*http.Response, error) { return nil, err } l.resp = resp + azcore.Log().Writef(azcore.LogLongRunningOperation, "Status %s", l.lro.Status()) if pollers.Failed(l.lro.Status()) { l.err = l.eu(resp) l.resp = nil @@ -237,9 +238,7 @@ func (l *LROPoller) PollUntilDone(ctx context.Context, freq time.Duration, respT return nil, err } if l.Done() { - status := l.lro.Status() - azcore.Log().Writef(azcore.LogLongRunningOperation, "Status %s", status) - logPollUntilDoneExit(status) + logPollUntilDoneExit(l.lro.Status()) return l.FinalResponse(ctx, respType) } d := freq From b35dcff6b8c6eec0f48f6198dff1867d95fb8d9e Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 08:17:37 -0700 Subject: [PATCH 10/21] export pipeline for pager-poller scenario --- sdk/armcore/poller.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index 0917c7fcf166..a1eaf2f5bf89 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -51,7 +51,7 @@ func NewLROPoller(pollerID string, finalState string, resp *azcore.Response, pl if err != nil { return nil, err } - return &LROPoller{lro: lro, pl: pl, eu: eu, resp: resp}, nil + return &LROPoller{lro: lro, Pipeline: pl, eu: eu, resp: resp}, nil } // NewLROPollerFromResumeToken creates an LROPoller from a resume token string. @@ -99,17 +99,17 @@ func NewLROPollerFromResumeToken(pollerID string, token string, pl azcore.Pipeli if err = json.Unmarshal([]byte(token), lro); err != nil { return nil, err } - return &LROPoller{lro: lro, pl: pl, eu: eu}, nil + return &LROPoller{lro: lro, Pipeline: pl, eu: eu}, nil } // LROPoller encapsulates state and logic for polling on long-running operations. // NOTE: this is only meant for internal use in generated code. type LROPoller struct { - lro lroPoller - pl azcore.Pipeline - eu ErrorUnmarshaller - resp *azcore.Response - err error + Pipeline azcore.Pipeline + lro lroPoller + eu ErrorUnmarshaller + resp *azcore.Response + err error } // Done returns true if the LRO has reached a terminal state. @@ -133,7 +133,7 @@ func (l *LROPoller) Poll(ctx context.Context) (*http.Response, error) { if err != nil { return nil, err } - resp, err := l.pl.Do(req) + resp, err := l.Pipeline.Do(req) if err != nil { // don't update the poller for failed requests return nil, err @@ -189,7 +189,7 @@ func (l *LROPoller) FinalResponse(ctx context.Context, respType interface{}) (*h if err != nil { return nil, err } - resp, err := l.pl.Do(req) + resp, err := l.Pipeline.Do(req) if err != nil { return nil, err } From 9eea8023f2f4be7c303c8ae9b6fd938d6db9aa01 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 09:11:31 -0700 Subject: [PATCH 11/21] fail on a 202 DELETE/POST with no polling URL --- sdk/armcore/poller.go | 4 ++++ sdk/armcore/poller_test.go | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index a1eaf2f5bf89..79a2a20c3423 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -45,6 +45,10 @@ func NewLROPoller(pollerID string, finalState string, resp *azcore.Response, pl // must test body poller last as it's a subset of the other pollers. // TODO: this is ambiguous for PATCH/PUT if it returns a 200 with no polling headers (sync completion) lro, err = body.New(resp, pollerID) + } else if m := resp.Request.Method; resp.StatusCode == http.StatusAccepted && (m == http.MethodDelete || m == http.MethodPost) { + // if we get here it means we have a 202 with no polling headers. + // for DELETE and POST this is a hard error per ARM RPC spec. + return nil, errors.New("response is missing polling URL") } else { lro = &nopPoller{} } diff --git a/sdk/armcore/poller_test.go b/sdk/armcore/poller_test.go index cd836441e8ee..b71aa9f7ed95 100644 --- a/sdk/armcore/poller_test.go +++ b/sdk/armcore/poller_test.go @@ -306,3 +306,18 @@ func TestNewLROPollerSuccessNoContent(t *testing.T) { t.Fatal("expected nil result") } } + +func TestNewLROPollerFail202NoHeaders(t *testing.T) { + srv, close := mock.NewServer() + defer close() + resp := initialResponse(http.MethodDelete, srv.URL(), http.NoBody) + resp.StatusCode = http.StatusAccepted + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} From ab32e4c835232138f96d15fee99643de3e51c394 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 09:32:50 -0700 Subject: [PATCH 12/21] differentiate between no response body and missing states this is important for body polling which can tolerate a missing response body for certain status codes. --- sdk/armcore/internal/pollers/async/async.go | 2 +- sdk/armcore/internal/pollers/body/body.go | 26 +++++-- .../internal/pollers/body/body_test.go | 70 +++++++++++++++++-- sdk/armcore/internal/pollers/pollers.go | 5 +- sdk/armcore/internal/pollers/pollers_test.go | 26 ++++++- 5 files changed, 111 insertions(+), 18 deletions(-) diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go index 5e25eef930ef..333da390db5a 100644 --- a/sdk/armcore/internal/pollers/async/async.go +++ b/sdk/armcore/internal/pollers/async/async.go @@ -56,7 +56,7 @@ func New(resp *azcore.Response, finalState string, pollerID string) (*Poller, er } // check for provisioning state state, err := pollers.GetProvisioningState(resp) - if errors.Is(err, pollers.ErrNoProvisioningState) { + if errors.Is(err, pollers.ErrNoBody) || errors.Is(err, pollers.ErrNoProvisioningState) { if resp.Request.Method == http.MethodPut { // initial response for a PUT requires a provisioning state return nil, err diff --git a/sdk/armcore/internal/pollers/body/body.go b/sdk/armcore/internal/pollers/body/body.go index dde56c1424bf..eb3ffad98b22 100644 --- a/sdk/armcore/internal/pollers/body/body.go +++ b/sdk/armcore/internal/pollers/body/body.go @@ -13,6 +13,11 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) +const ( + stateSucceeded = "Succeeded" + stateInProgress = "InProgress" +) + // Applicable returns true if the LRO is using no headers, just provisioning state. // This is only applicable to PATCH and PUT methods and assumes no polling headers. func Applicable(resp *azcore.Response) bool { @@ -37,9 +42,9 @@ func New(resp *azcore.Response, pollerID string) (*Poller, error) { } // default initial state to InProgress. depending on the HTTP // status code and provisioning state, we might change the value. - curState := "InProgress" + curState := stateInProgress provState, err := pollers.GetProvisioningState(resp) - if err != nil && !errors.Is(err, pollers.ErrNoProvisioningState) { + if err != nil && !errors.Is(err, pollers.ErrNoBody) && !errors.Is(err, pollers.ErrNoProvisioningState) { return nil, err } if resp.StatusCode == http.StatusCreated && provState != "" { @@ -50,10 +55,10 @@ func New(resp *azcore.Response, pollerID string) (*Poller, error) { curState = provState } else if provState == "" { // for a 200, absense of provisioning state indicates success - curState = "Succeeded" + curState = stateSucceeded } } else if resp.StatusCode == http.StatusNoContent { - curState = "Succeeded" + curState = stateSucceeded } p.CurState = curState return p, nil @@ -71,10 +76,17 @@ func (p *Poller) Done() bool { // Update updates the Poller from the polling response. func (p *Poller) Update(resp *azcore.Response) error { + if resp.StatusCode == http.StatusNoContent { + p.CurState = stateSucceeded + return nil + } state, err := pollers.GetProvisioningState(resp) - if errors.Is(err, pollers.ErrNoProvisioningState) { - // absense of any provisioning state is considered terminal success - state = "Succeeded" + if errors.Is(err, pollers.ErrNoBody) { + // a missing response body in non-204 case is an error + return err + } else if errors.Is(err, pollers.ErrNoProvisioningState) { + // a response body without provisioning state is considered terminal success + state = stateSucceeded } else if err != nil { return err } diff --git a/sdk/armcore/internal/pollers/body/body_test.go b/sdk/armcore/internal/pollers/body/body_test.go index a4a08ef04897..9cbd5821b028 100644 --- a/sdk/armcore/internal/pollers/body/body_test.go +++ b/sdk/armcore/internal/pollers/body/body_test.go @@ -6,12 +6,14 @@ package body import ( + "errors" "io" "io/ioutil" "net/http" "strings" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) @@ -33,11 +35,12 @@ func initialResponse(method string, resp io.Reader) *azcore.Response { } } -func pollingResponse(resp io.Reader) *azcore.Response { +func pollingResponse(status int, resp io.Reader) *azcore.Response { return &azcore.Response{ Response: &http.Response{ - Body: ioutil.NopCloser(resp), - Header: http.Header{}, + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + StatusCode: status, }, } } @@ -84,7 +87,7 @@ func TestNew(t *testing.T) { if u := poller.URL(); u != fakeResourceURL { t.Fatalf("unexpected polling URL %s", u) } - if err := poller.Update(pollingResponse(strings.NewReader(`{ "properties": { "provisioningState": "InProgress" } }`))); err != nil { + if err := poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{ "properties": { "provisioningState": "InProgress" } }`))); err != nil { t.Fatal(err) } if s := poller.Status(); s != "InProgress" { @@ -92,7 +95,7 @@ func TestNew(t *testing.T) { } } -func TestUpdateNoProvState(t *testing.T) { +func TestUpdateNoProvStateFail(t *testing.T) { const jsonBody = `{ "properties": { "provisioningState": "Started" } }` resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) resp.StatusCode = http.StatusOK @@ -112,12 +115,65 @@ func TestUpdateNoProvState(t *testing.T) { if u := poller.URL(); u != fakeResourceURL { t.Fatalf("unexpected polling URL %s", u) } - if err := poller.Update(pollingResponse(http.NoBody)); err != nil { + err = poller.Update(pollingResponse(http.StatusOK, http.NoBody)) + if err == nil { + t.Fatal("unexpected nil error") + } + if !errors.Is(err, pollers.ErrNoBody) { + t.Fatalf("unexpected error type %T", err) + } +} + +func TestUpdateNoProvStateSuccess(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { t.Fatal(err) } - if s := poller.Status(); s != "Succeeded" { + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{}`))) + if err != nil { + t.Fatal(err) + } +} + +func TestUpdateNoProvState204(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { t.Fatalf("unexpected status %s", s) } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusNoContent, http.NoBody)) + if err != nil { + t.Fatal(err) + } } func TestNewNoInitialProvStateOK(t *testing.T) { diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index 10d1cf505472..58f0bad57f6d 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -37,7 +37,7 @@ func getJSON(resp *azcore.Response) (map[string]interface{}, error) { } resp.Body.Close() if len(body) == 0 { - return map[string]interface{}{}, nil + return nil, ErrNoBody } // put the body back so it's available to others resp.Body = ioutil.NopCloser(bytes.NewReader(body)) @@ -131,6 +131,9 @@ func MakeID(pollerID string, kind string) string { return fmt.Sprintf("%s;%s", pollerID, kind) } +// ErrNoBody is returned if the response didn't contain a body. +var ErrNoBody = errors.New("the response did not contain a body") + // ErrNoStatus is returned if the response body didn't contain a status. var ErrNoStatus = errors.New("the response did not contain a status") diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go index 14b8bee1c471..4e746efaeab3 100644 --- a/sdk/armcore/internal/pollers/pollers_test.go +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -46,13 +46,35 @@ func TestGetStatusSuccess(t *testing.T) { } } -func TestGetStatusError(t *testing.T) { +func TestGetNoBody(t *testing.T) { resp := azcore.Response{ Response: &http.Response{ Body: http.NoBody, }, } status, err := GetStatus(&resp) + if !errors.Is(err, ErrNoBody) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatal("expected empty status") + } + status, err = GetProvisioningState(&resp) + if !errors.Is(err, ErrNoBody) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatal("expected empty status") + } +} + +func TestGetStatusError(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("{}")), + }, + } + status, err := GetStatus(&resp) if !errors.Is(err, ErrNoStatus) { t.Fatalf("unexpected error %T", err) } @@ -80,7 +102,7 @@ func TestGetProvisioningState(t *testing.T) { func TestGetProvisioningStateError(t *testing.T) { resp := azcore.Response{ Response: &http.Response{ - Body: http.NoBody, + Body: ioutil.NopCloser(strings.NewReader("{}")), }, } state, err := GetProvisioningState(&resp) From 24e29f9f917822c76361dd981098098aadcc7747 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 09:48:35 -0700 Subject: [PATCH 13/21] relax provisioning state requirement on initial async PUT response --- sdk/armcore/internal/pollers/async/async.go | 7 +++++-- sdk/armcore/internal/pollers/async/async_test.go | 14 +++++++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go index 333da390db5a..d1d187ba91a0 100644 --- a/sdk/armcore/internal/pollers/async/async.go +++ b/sdk/armcore/internal/pollers/async/async.go @@ -57,10 +57,13 @@ func New(resp *azcore.Response, finalState string, pollerID string) (*Poller, er // check for provisioning state state, err := pollers.GetProvisioningState(resp) if errors.Is(err, pollers.ErrNoBody) || errors.Is(err, pollers.ErrNoProvisioningState) { - if resp.Request.Method == http.MethodPut { + // NOTE: the ARM RPC spec explicitly states that for async PUT the initial response MUST + // contain a provisioning state. to maintain compat with track 1 and other implementations + // we are explicitly relaxing this requirement. + /*if resp.Request.Method == http.MethodPut { // initial response for a PUT requires a provisioning state return nil, err - } + }*/ // for DELETE/PATCH/POST, provisioning state is optional state = "InProgress" } else if err != nil { diff --git a/sdk/armcore/internal/pollers/async/async_test.go b/sdk/armcore/internal/pollers/async/async_test.go index c9f252d1cb15..711285cb2f9a 100644 --- a/sdk/armcore/internal/pollers/async/async_test.go +++ b/sdk/armcore/internal/pollers/async/async_test.go @@ -102,16 +102,20 @@ func TestNewDeleteNoProvState(t *testing.T) { } } -func TestNewFail(t *testing.T) { +func TestNewPutNoProvState(t *testing.T) { // missing provisioning state on initial response + // NOTE: ARM RPC forbids this but we allow it for back-compat resp := initialResponse(http.MethodPut, http.NoBody) resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) poller, err := New(resp, "", "pollerID") - if err == nil { - t.Fatal("unexpected nil error") + if err != nil { + t.Fatal(err) } - if poller != nil { - t.Fatal("expected nil poller") + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) } } From 216a0e55a534142f4ff133de3f646849fd5274e6 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 13:09:07 -0700 Subject: [PATCH 14/21] use provisioningState for loc polling when available --- sdk/armcore/internal/pollers/loc/loc.go | 17 ++++++- sdk/armcore/internal/pollers/loc/loc_test.go | 53 ++++++++++++++++++-- 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go index 853410bd66ea..07375c37b817 100644 --- a/sdk/armcore/internal/pollers/loc/loc.go +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -6,6 +6,7 @@ package loc import ( + "errors" "fmt" "net/http" @@ -56,12 +57,24 @@ func (p *Poller) Update(resp *azcore.Response) error { if h := resp.Header.Get(pollers.HeaderLocation); h != "" { p.PollURL = h } - // any 2xx code other than 202 indicates success - if resp.HasStatusCode(http.StatusOK, http.StatusCreated, http.StatusNoContent) { + if resp.HasStatusCode(http.StatusOK, http.StatusCreated) { + // if a 200/201 returns a provisioning state, use that instead + state, err := pollers.GetProvisioningState(resp) + if err != nil && !errors.Is(err, pollers.ErrNoBody) && !errors.Is(err, pollers.ErrNoProvisioningState) { + return err + } + if state != "" { + p.CurState = state + } else { + // a 200/201 with no provisioning state indicates success + p.CurState = "Succeeded" + } + } else if resp.StatusCode == http.StatusNoContent { p.CurState = "Succeeded" } else if resp.StatusCode > 399 && resp.StatusCode < 500 { p.CurState = "Failed" } + // a 202 falls through, means the LRO is still in progress and we don't check for provisioning state return nil } diff --git a/sdk/armcore/internal/pollers/loc/loc_test.go b/sdk/armcore/internal/pollers/loc/loc_test.go index aa5986489c06..ce14b9393bf3 100644 --- a/sdk/armcore/internal/pollers/loc/loc_test.go +++ b/sdk/armcore/internal/pollers/loc/loc_test.go @@ -6,7 +6,10 @@ package loc import ( + "io" + "io/ioutil" "net/http" + "strings" "testing" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" @@ -28,9 +31,10 @@ func initialResponse(method string) *azcore.Response { } } -func pollingResponse(statusCode int) *azcore.Response { +func pollingResponse(statusCode int, body io.Reader) *azcore.Response { return &azcore.Response{ Response: &http.Response{ + Body: ioutil.NopCloser(body), Header: http.Header{}, StatusCode: statusCode, }, @@ -72,7 +76,7 @@ func TestNew(t *testing.T) { if u := poller.URL(); u != fakePollingURL1 { t.Fatalf("unexpected polling URL %s", u) } - pr := pollingResponse(http.StatusAccepted) + pr := pollingResponse(http.StatusAccepted, http.NoBody) pr.Header.Set(pollers.HeaderLocation, fakePollingURL2) if err := poller.Update(pr); err != nil { t.Fatal(err) @@ -80,16 +84,57 @@ func TestNew(t *testing.T) { if u := poller.URL(); u != fakePollingURL2 { t.Fatalf("unexpected polling URL %s", u) } - if err := poller.Update(pollingResponse(http.StatusNoContent)); err != nil { + if err := poller.Update(pollingResponse(http.StatusNoContent, http.NoBody)); err != nil { t.Fatal(err) } if s := poller.Status(); s != "Succeeded" { t.Fatalf("unexpected status %s", s) } - if err := poller.Update(pollingResponse(http.StatusConflict)); err != nil { + if err := poller.Update(pollingResponse(http.StatusConflict, http.NoBody)); err != nil { t.Fatal(err) } if s := poller.Status(); s != "Failed" { t.Fatalf("unexpected status %s", s) } } + +func TestUpdateWithProvState(t *testing.T) { + resp := initialResponse(http.MethodPut) + resp.Header.Set(pollers.HeaderLocation, fakePollingURL1) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL1 { + t.Fatalf("unexpected polling URL %s", u) + } + pr := pollingResponse(http.StatusAccepted, http.NoBody) + pr.Header.Set(pollers.HeaderLocation, fakePollingURL2) + if err := poller.Update(pr); err != nil { + t.Fatal(err) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{ "properties": { "provisioningState": "Updating" } }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Updating" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(http.StatusOK, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} From 1c947138e251a728271af635aedd031a5ee74012 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 14:37:52 -0700 Subject: [PATCH 15/21] fix up token decoding --- sdk/armcore/internal/pollers/pollers.go | 22 ++++++++++-- sdk/armcore/internal/pollers/pollers_test.go | 35 +++++++++++++++++++- sdk/armcore/poller.go | 16 ++++----- 3 files changed, 61 insertions(+), 12 deletions(-) diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index 58f0bad57f6d..4c372f60f135 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -126,9 +126,27 @@ func IsValidURL(s string) bool { return err == nil && u.IsAbs() } -// MakeID returns the unique poller identifier in the format pollerID;poller. +const idSeparator = ";" + +// MakeID returns the poller ID from the provided values. func MakeID(pollerID string, kind string) string { - return fmt.Sprintf("%s;%s", pollerID, kind) + return fmt.Sprintf("%s%s%s", pollerID, idSeparator, kind) +} + +// DecodeID decodes the poller ID, returning [pollerID, kind] or an error. +func DecodeID(tk string) (string, string, error) { + raw := strings.Split(tk, idSeparator) + // strings.Split will include any/all whitespace strings, we want to omit those + parts := []string{} + for _, r := range raw { + if s := strings.TrimSpace(r); s != "" { + parts = append(parts, s) + } + } + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid token %s", tk) + } + return parts[0], parts[1], nil } // ErrNoBody is returned if the response didn't contain a body. diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go index 4e746efaeab3..b24418952688 100644 --- a/sdk/armcore/internal/pollers/pollers_test.go +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -120,7 +120,7 @@ func TestMakeID(t *testing.T) { kind = "kind" ) id := MakeID(pollerID, kind) - parts := strings.Split(id, ";") + parts := strings.Split(id, idSeparator) if l := len(parts); l != 2 { t.Fatalf("unexpected length %d", l) } @@ -132,6 +132,39 @@ func TestMakeID(t *testing.T) { } } +func TestDecodeID(t *testing.T) { + id, kind, err := DecodeID("") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID("invalid_token") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID("invalid_token;") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID(" ;invalid_token") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID("invalid;token;too") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID("pollerID;kind") + if err != nil { + t.Fatal(err) + } + if id != "pollerID" { + t.Fatalf("unexpected ID %s", id) + } + if kind != "kind" { + t.Fatalf("unexpected kin %s", kind) + } +} + func TestIsValidURL(t *testing.T) { if IsValidURL("/foo") { t.Fatal("unexpected valid URL") diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index 79a2a20c3423..ebb7e260044a 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -12,7 +12,6 @@ import ( "fmt" "io/ioutil" "net/http" - "strings" "time" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" @@ -76,18 +75,17 @@ func NewLROPollerFromResumeToken(pollerID string, token string, pl azcore.Pipeli if !ok { return nil, fmt.Errorf("invalid type format %T", t) } - // the type is encoded as "pollerID;pollerType" - sem := strings.LastIndex(tt, ";") - if sem < 0 { - return nil, fmt.Errorf("invalid poller type %s", tt) + ttID, ttKind, err := pollers.DecodeID(tt) + if err != nil { + return nil, err } // ensure poller types match - if received := tt[:sem]; received != pollerID { - return nil, fmt.Errorf("cannot resume from this poller token. expected %s, received %s", pollerID, received) + if ttID != pollerID { + return nil, fmt.Errorf("cannot resume from this poller token. expected %s, received %s", pollerID, ttID) } // now rehydrate the poller based on the encoded poller type var lro lroPoller - switch pt := tt[sem+1:]; pt { + switch ttKind { case "async": azcore.Log().Write(azcore.LogLongRunningOperation, "Resuming Azure-AsyncOperation poller.") lro = &async.Poller{} @@ -98,7 +96,7 @@ func NewLROPollerFromResumeToken(pollerID string, token string, pl azcore.Pipeli azcore.Log().Write(azcore.LogLongRunningOperation, "Resuming Body poller.") lro = &body.Poller{} default: - return nil, fmt.Errorf("unhandled poller type %s", pt) + return nil, fmt.Errorf("unhandled poller type %s", ttKind) } if err = json.Unmarshal([]byte(token), lro); err != nil { return nil, err From 311af89299f0c5b4670ab9f32ef97fc607ab672f Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 14:56:45 -0700 Subject: [PATCH 16/21] consolidate status constants --- sdk/armcore/internal/pollers/async/async.go | 2 +- sdk/armcore/internal/pollers/body/body.go | 15 +++++---------- sdk/armcore/internal/pollers/loc/loc.go | 8 ++++---- sdk/armcore/internal/pollers/pollers.go | 11 ++++++----- sdk/armcore/poller.go | 2 +- 5 files changed, 17 insertions(+), 21 deletions(-) diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go index d1d187ba91a0..307a19168ca2 100644 --- a/sdk/armcore/internal/pollers/async/async.go +++ b/sdk/armcore/internal/pollers/async/async.go @@ -65,7 +65,7 @@ func New(resp *azcore.Response, finalState string, pollerID string) (*Poller, er return nil, err }*/ // for DELETE/PATCH/POST, provisioning state is optional - state = "InProgress" + state = pollers.StatusInProgress } else if err != nil { return nil, err } diff --git a/sdk/armcore/internal/pollers/body/body.go b/sdk/armcore/internal/pollers/body/body.go index eb3ffad98b22..7a436661af8f 100644 --- a/sdk/armcore/internal/pollers/body/body.go +++ b/sdk/armcore/internal/pollers/body/body.go @@ -13,11 +13,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) -const ( - stateSucceeded = "Succeeded" - stateInProgress = "InProgress" -) - // Applicable returns true if the LRO is using no headers, just provisioning state. // This is only applicable to PATCH and PUT methods and assumes no polling headers. func Applicable(resp *azcore.Response) bool { @@ -42,7 +37,7 @@ func New(resp *azcore.Response, pollerID string) (*Poller, error) { } // default initial state to InProgress. depending on the HTTP // status code and provisioning state, we might change the value. - curState := stateInProgress + curState := pollers.StatusInProgress provState, err := pollers.GetProvisioningState(resp) if err != nil && !errors.Is(err, pollers.ErrNoBody) && !errors.Is(err, pollers.ErrNoProvisioningState) { return nil, err @@ -55,10 +50,10 @@ func New(resp *azcore.Response, pollerID string) (*Poller, error) { curState = provState } else if provState == "" { // for a 200, absense of provisioning state indicates success - curState = stateSucceeded + curState = pollers.StatusSucceeded } } else if resp.StatusCode == http.StatusNoContent { - curState = stateSucceeded + curState = pollers.StatusSucceeded } p.CurState = curState return p, nil @@ -77,7 +72,7 @@ func (p *Poller) Done() bool { // Update updates the Poller from the polling response. func (p *Poller) Update(resp *azcore.Response) error { if resp.StatusCode == http.StatusNoContent { - p.CurState = stateSucceeded + p.CurState = pollers.StatusSucceeded return nil } state, err := pollers.GetProvisioningState(resp) @@ -86,7 +81,7 @@ func (p *Poller) Update(resp *azcore.Response) error { return err } else if errors.Is(err, pollers.ErrNoProvisioningState) { // a response body without provisioning state is considered terminal success - state = stateSucceeded + state = pollers.StatusSucceeded } else if err != nil { return err } diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go index 07375c37b817..c1d984b5bb25 100644 --- a/sdk/armcore/internal/pollers/loc/loc.go +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -36,7 +36,7 @@ func New(resp *azcore.Response, pollerID string) (*Poller, error) { p := &Poller{ Type: pollers.MakeID(pollerID, "loc"), PollURL: locURL, - CurState: "InProgress", + CurState: pollers.StatusInProgress, } return p, nil } @@ -67,12 +67,12 @@ func (p *Poller) Update(resp *azcore.Response) error { p.CurState = state } else { // a 200/201 with no provisioning state indicates success - p.CurState = "Succeeded" + p.CurState = pollers.StatusSucceeded } } else if resp.StatusCode == http.StatusNoContent { - p.CurState = "Succeeded" + p.CurState = pollers.StatusSucceeded } else if resp.StatusCode > 399 && resp.StatusCode < 500 { - p.CurState = "Failed" + p.CurState = pollers.StatusFailed } // a 202 falls through, means the LRO is still in progress and we don't check for provisioning state return nil diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index 4c372f60f135..cfa103f167b9 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -23,9 +23,10 @@ const ( ) const ( - statusSucceeded = "succeeded" - statusCanceled = "canceled" - statusFailed = "failed" + StatusSucceeded = "Succeeded" + StatusCanceled = "Canceled" + StatusFailed = "Failed" + StatusInProgress = "InProgress" ) // reads the response body into a raw JSON object. @@ -85,12 +86,12 @@ func status(jsonBody map[string]interface{}) string { // IsTerminalState returns true if the LRO's state is terminal. func IsTerminalState(s string) bool { - return strings.EqualFold(s, statusSucceeded) || strings.EqualFold(s, statusFailed) || strings.EqualFold(s, statusCanceled) + return strings.EqualFold(s, StatusSucceeded) || strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) } // Failed returns true if the LRO's state is terminal failure. func Failed(s string) bool { - return strings.EqualFold(s, statusFailed) || strings.EqualFold(s, statusCanceled) + return strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) } // GetStatus returns the LRO's status from the response body. diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index ebb7e260044a..6927406b9f4d 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -294,7 +294,7 @@ func (*nopPoller) FinalGetURL() string { } func (*nopPoller) Status() string { - return "Succeeded" + return pollers.StatusSucceeded } // returns true if the LRO response contains a valid HTTP status code From 7d4da50f61bfa85dc69eeb70ab02e98f81410da4 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 15:00:34 -0700 Subject: [PATCH 17/21] fix code comment --- sdk/armcore/internal/pollers/pollers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index cfa103f167b9..cc88351cc662 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -30,7 +30,7 @@ const ( ) // reads the response body into a raw JSON object. -// returns an empty object if there was no content. +// returns ErrNoBody if there was no content. func getJSON(resp *azcore.Response) (map[string]interface{}, error) { body, err := ioutil.ReadAll(resp.Body) if err != nil { From f38016c6cc09a690fd347f7c03c5aeefb7e90ff2 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 9 Jun 2021 15:09:25 -0700 Subject: [PATCH 18/21] fix closing of response body --- sdk/armcore/internal/pollers/pollers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index cc88351cc662..7952b52b8702 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -33,10 +33,10 @@ const ( // returns ErrNoBody if there was no content. func getJSON(resp *azcore.Response) (map[string]interface{}, error) { body, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() if err != nil { return nil, err } - resp.Body.Close() if len(body) == 0 { return nil, ErrNoBody } From 5578d4f0f2e6dbbe089295b2be8e3c7167a9d843 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Thu, 10 Jun 2021 09:31:56 -0700 Subject: [PATCH 19/21] add comments to fields --- sdk/armcore/internal/pollers/async/async.go | 25 ++++++++++++++++----- sdk/armcore/internal/pollers/body/body.go | 9 ++++++-- sdk/armcore/internal/pollers/loc/loc.go | 9 ++++++-- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go index 307a19168ca2..74f161e948dd 100644 --- a/sdk/armcore/internal/pollers/async/async.go +++ b/sdk/armcore/internal/pollers/async/async.go @@ -27,13 +27,26 @@ func Applicable(resp *azcore.Response) bool { // Poller is an LRO poller that uses the Azure-AsyncOperation pattern. type Poller struct { - Type string `json:"type"` - AsyncURL string `json:"asyncURL"` - LocURL string `json:"locURL"` - OrigURL string `json:"origURL"` - Method string `json:"method"` + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL from Azure-AsyncOperation header. + AsyncURL string `json:"asyncURL"` + + // The URL from Location header. + LocURL string `json:"locURL"` + + // The URL from the initial LRO request. + OrigURL string `json:"origURL"` + + // The HTTP method from the initial LRO request. + Method string `json:"method"` + + // The value of final-state-via from swagger, can be the empty string. FinalState string `json:"finalState"` - CurState string `json:"state"` + + // The LRO's current state. + CurState string `json:"state"` } // New creates a new Poller from the provided initial response and final-state type. diff --git a/sdk/armcore/internal/pollers/body/body.go b/sdk/armcore/internal/pollers/body/body.go index 7a436661af8f..1254580cd449 100644 --- a/sdk/armcore/internal/pollers/body/body.go +++ b/sdk/armcore/internal/pollers/body/body.go @@ -23,8 +23,13 @@ func Applicable(resp *azcore.Response) bool { // Poller is an LRO poller that uses the Body pattern. type Poller struct { - Type string `json:"type"` - PollURL string `json:"pollURL"` + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. CurState string `json:"state"` } diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go index c1d984b5bb25..683663681236 100644 --- a/sdk/armcore/internal/pollers/loc/loc.go +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -21,8 +21,13 @@ func Applicable(resp *azcore.Response) bool { // Poller is an LRO poller that uses the Location pattern. type Poller struct { - Type string `json:"type"` - PollURL string `json:"pollURL"` + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. CurState string `json:"state"` } From a7ffaa0369bcea75cd41830d6cedc3e5b9877436 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Thu, 10 Jun 2021 09:58:18 -0700 Subject: [PATCH 20/21] simplify error handling for missing states --- sdk/armcore/internal/pollers/async/async.go | 4 +++- .../internal/pollers/async/async_test.go | 20 +++++++++++++++++++ sdk/armcore/internal/pollers/body/body.go | 4 ++-- sdk/armcore/internal/pollers/loc/loc.go | 2 +- sdk/armcore/internal/pollers/pollers.go | 20 ++++--------------- sdk/armcore/internal/pollers/pollers_test.go | 8 ++++---- 6 files changed, 34 insertions(+), 24 deletions(-) diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go index 74f161e948dd..7476b55ee76e 100644 --- a/sdk/armcore/internal/pollers/async/async.go +++ b/sdk/armcore/internal/pollers/async/async.go @@ -69,7 +69,7 @@ func New(resp *azcore.Response, finalState string, pollerID string) (*Poller, er } // check for provisioning state state, err := pollers.GetProvisioningState(resp) - if errors.Is(err, pollers.ErrNoBody) || errors.Is(err, pollers.ErrNoProvisioningState) { + if errors.Is(err, pollers.ErrNoBody) || state == "" { // NOTE: the ARM RPC spec explicitly states that for async PUT the initial response MUST // contain a provisioning state. to maintain compat with track 1 and other implementations // we are explicitly relaxing this requirement. @@ -96,6 +96,8 @@ func (p *Poller) Update(resp *azcore.Response) error { state, err := pollers.GetStatus(resp) if err != nil { return err + } else if state == "" { + return errors.New("the response did not contain a status") } p.CurState = state return nil diff --git a/sdk/armcore/internal/pollers/async/async_test.go b/sdk/armcore/internal/pollers/async/async_test.go index 711285cb2f9a..b798ea16d6b5 100644 --- a/sdk/armcore/internal/pollers/async/async_test.go +++ b/sdk/armcore/internal/pollers/async/async_test.go @@ -164,3 +164,23 @@ func TestNewFinalGetOrigin(t *testing.T) { t.Fatalf("unexpected polling URL %s", u) } } + +func TestNewPutNoProvStateOnUpdate(t *testing.T) { + // missing provisioning state on initial response + // NOTE: ARM RPC forbids this but we allow it for back-compat + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(strings.NewReader("{}"))); err == nil { + t.Fatal("unexpected nil error") + } +} diff --git a/sdk/armcore/internal/pollers/body/body.go b/sdk/armcore/internal/pollers/body/body.go index 1254580cd449..55f83739e8b9 100644 --- a/sdk/armcore/internal/pollers/body/body.go +++ b/sdk/armcore/internal/pollers/body/body.go @@ -44,7 +44,7 @@ func New(resp *azcore.Response, pollerID string) (*Poller, error) { // status code and provisioning state, we might change the value. curState := pollers.StatusInProgress provState, err := pollers.GetProvisioningState(resp) - if err != nil && !errors.Is(err, pollers.ErrNoBody) && !errors.Is(err, pollers.ErrNoProvisioningState) { + if err != nil && !errors.Is(err, pollers.ErrNoBody) { return nil, err } if resp.StatusCode == http.StatusCreated && provState != "" { @@ -84,7 +84,7 @@ func (p *Poller) Update(resp *azcore.Response) error { if errors.Is(err, pollers.ErrNoBody) { // a missing response body in non-204 case is an error return err - } else if errors.Is(err, pollers.ErrNoProvisioningState) { + } else if state == "" { // a response body without provisioning state is considered terminal success state = pollers.StatusSucceeded } else if err != nil { diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go index 683663681236..dd6e517852ae 100644 --- a/sdk/armcore/internal/pollers/loc/loc.go +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -65,7 +65,7 @@ func (p *Poller) Update(resp *azcore.Response) error { if resp.HasStatusCode(http.StatusOK, http.StatusCreated) { // if a 200/201 returns a provisioning state, use that instead state, err := pollers.GetProvisioningState(resp) - if err != nil && !errors.Is(err, pollers.ErrNoBody) && !errors.Is(err, pollers.ErrNoProvisioningState) { + if err != nil && !errors.Is(err, pollers.ErrNoBody) { return err } if state != "" { diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index 7952b52b8702..a030e081ae1f 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -96,29 +96,23 @@ func Failed(s string) bool { // GetStatus returns the LRO's status from the response body. // Typically used for Azure-AsyncOperation flows. -// If there is no status in the response body ErrNoStatus is returned. +// If there is no status in the response body the empty string is returned. func GetStatus(resp *azcore.Response) (string, error) { jsonBody, err := getJSON(resp) if err != nil { return "", err } - if s := status(jsonBody); s != "" { - return s, nil - } - return "", ErrNoStatus + return status(jsonBody), nil } // GetProvisioningState returns the LRO's state from the response body. -// If there is no state in the response body ErrNoProvisioningState is returned. +// If there is no state in the response body the empty string is returned. func GetProvisioningState(resp *azcore.Response) (string, error) { jsonBody, err := getJSON(resp) if err != nil { return "", err } - if ps := provisioningState(jsonBody); ps != "" { - return ps, nil - } - return "", ErrNoProvisioningState + return provisioningState(jsonBody), nil } // IsValidURL verifies that the URL is valid and absolute. @@ -152,9 +146,3 @@ func DecodeID(tk string) (string, string, error) { // ErrNoBody is returned if the response didn't contain a body. var ErrNoBody = errors.New("the response did not contain a body") - -// ErrNoStatus is returned if the response body didn't contain a status. -var ErrNoStatus = errors.New("the response did not contain a status") - -// ErrNoProvisioningState is returned if the response body didn't contain a provisioning state. -var ErrNoProvisioningState = errors.New("the response did not contain a provisioning state") diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go index b24418952688..0523387daba9 100644 --- a/sdk/armcore/internal/pollers/pollers_test.go +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -75,8 +75,8 @@ func TestGetStatusError(t *testing.T) { }, } status, err := GetStatus(&resp) - if !errors.Is(err, ErrNoStatus) { - t.Fatalf("unexpected error %T", err) + if err != nil { + t.Fatal(err) } if status != "" { t.Fatalf("expected empty status, got %s", status) @@ -106,8 +106,8 @@ func TestGetProvisioningStateError(t *testing.T) { }, } state, err := GetProvisioningState(&resp) - if !errors.Is(err, ErrNoProvisioningState) { - t.Fatalf("unexpected error %T", err) + if err != nil { + t.Fatal(err) } if state != "" { t.Fatalf("expected empty provisioning state, got %s", state) From b693086e2aed504f6856c72536b5d4c4d5cf3091 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Fri, 11 Jun 2021 07:35:23 -0700 Subject: [PATCH 21/21] add check for empty Location header --- sdk/armcore/internal/pollers/loc/loc.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go index dd6e517852ae..567dd0ad0c8e 100644 --- a/sdk/armcore/internal/pollers/loc/loc.go +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -35,6 +35,9 @@ type Poller struct { func New(resp *azcore.Response, pollerID string) (*Poller, error) { azcore.Log().Write(azcore.LogLongRunningOperation, "Using Location poller.") locURL := resp.Header.Get(pollers.HeaderLocation) + if locURL == "" { + return nil, errors.New("response is missing Location header") + } if !pollers.IsValidURL(locURL) { return nil, fmt.Errorf("invalid polling URL %s", locURL) }