diff --git a/sdk/armcore/go.mod b/sdk/armcore/go.mod index 9a8cf2842281..91c7df6b1aeb 100644 --- a/sdk/armcore/go.mod +++ b/sdk/armcore/go.mod @@ -3,6 +3,6 @@ module github.com/Azure/azure-sdk-for-go/sdk/armcore go 1.14 require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0 - github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.2 + github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1 ) diff --git a/sdk/armcore/go.sum b/sdk/armcore/go.sum index 91b3874be2ef..50a97bab95d6 100644 --- a/sdk/armcore/go.sum +++ b/sdk/armcore/go.sum @@ -1,7 +1,7 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0 h1:4HBTI/9UDZN7tsXyB5TYP3xCv5xVHIUTbvHHH2HFxQY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0/go.mod h1:pElNP+u99BvCZD+0jOlhI9OC/NB2IDTOTGZOZH0Qhq8= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0 h1:HG1ggl8L3ZkV/Ydanf7lKr5kkhhPGCpWdnr1J6v7cO4= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.2 h1:UC4vfOhW2l0f2QOCQpOxJS4/K6oKFy2tQZE+uWU1MEo= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.2/go.mod h1:MVdrcUC4Hup35qHym3VdzoW+NBgBxrta9Vei97jRtM8= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1 h1:vx8McI56N5oLSQu8xa+xdiE0fjQq8W8Zt49vHP8Rygw= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -10,7 +10,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go new file mode 100644 index 000000000000..7476b55ee76e --- /dev/null +++ b/sdk/armcore/internal/pollers/async/async.go @@ -0,0 +1,133 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + finalStateAsync = "azure-async-operation" + finalStateLoc = "location" + finalStateOrig = "original-uri" +) + +// Applicable returns true if the LRO is using Azure-AsyncOperation. +func Applicable(resp *azcore.Response) bool { + return resp.Header.Get(pollers.HeaderAzureAsync) != "" +} + +// Poller is an LRO poller that uses the Azure-AsyncOperation pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL from Azure-AsyncOperation header. + AsyncURL string `json:"asyncURL"` + + // The URL from Location header. + LocURL string `json:"locURL"` + + // The URL from the initial LRO request. + OrigURL string `json:"origURL"` + + // The HTTP method from the initial LRO request. + Method string `json:"method"` + + // The value of final-state-via from swagger, can be the empty string. + FinalState string `json:"finalState"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response and final-state type. +func New(resp *azcore.Response, finalState string, pollerID string) (*Poller, error) { + azcore.Log().Write(azcore.LogLongRunningOperation, "Using Azure-AsyncOperation poller.") + asyncURL := resp.Header.Get(pollers.HeaderAzureAsync) + if asyncURL == "" { + return nil, errors.New("response is missing Azure-AsyncOperation header") + } + if !pollers.IsValidURL(asyncURL) { + return nil, fmt.Errorf("invalid polling URL %s", asyncURL) + } + p := &Poller{ + Type: pollers.MakeID(pollerID, "async"), + AsyncURL: asyncURL, + LocURL: resp.Header.Get(pollers.HeaderLocation), + OrigURL: resp.Request.URL.String(), + Method: resp.Request.Method, + FinalState: finalState, + } + // check for provisioning state + state, err := pollers.GetProvisioningState(resp) + if errors.Is(err, pollers.ErrNoBody) || state == "" { + // NOTE: the ARM RPC spec explicitly states that for async PUT the initial response MUST + // contain a provisioning state. to maintain compat with track 1 and other implementations + // we are explicitly relaxing this requirement. + /*if resp.Request.Method == http.MethodPut { + // initial response for a PUT requires a provisioning state + return nil, err + }*/ + // for DELETE/PATCH/POST, provisioning state is optional + state = pollers.StatusInProgress + } else if err != nil { + return nil, err + } + p.CurState = state + return p, nil +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *azcore.Response) error { + state, err := pollers.GetStatus(resp) + if err != nil { + return err + } else if state == "" { + return errors.New("the response did not contain a status") + } + p.CurState = state + return nil +} + +// FinalGetURL returns the URL to perform a final GET for the payload, or the empty string if not required. +func (p *Poller) FinalGetURL() string { + if p.Method == http.MethodPatch || p.Method == http.MethodPut { + // for PATCH and PUT, the final GET is on the original resource URL + return p.OrigURL + } else if p.Method == http.MethodPost { + if p.FinalState == finalStateAsync { + return "" + } else if p.FinalState == finalStateOrig { + return p.OrigURL + } else if p.LocURL != "" { + // ideally FinalState would be set to "location" but it isn't always. + // must check last due to more permissive condition. + return p.LocURL + } + } + return "" +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.AsyncURL +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/armcore/internal/pollers/async/async_test.go b/sdk/armcore/internal/pollers/async/async_test.go new file mode 100644 index 000000000000..b798ea16d6b5 --- /dev/null +++ b/sdk/armcore/internal/pollers/async/async_test.go @@ -0,0 +1,186 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + fakePollingURL = "https://foo.bar.baz/status" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, resp io.Reader) *azcore.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + Request: req, + }, + } +} + +func pollingResponse(resp io.Reader) *azcore.Response { + return &azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + }, + } +} + +func TestApplicable(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + }, + } + if Applicable(&resp) { + t.Fatal("missing Azure-AsyncOperation should not be applicable") + } + resp.Response.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + if !Applicable(&resp) { + t.Fatal("having Azure-AsyncOperation should be applicable") + } +} + +func TestNew(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(strings.NewReader(`{ "status": "InProgress" }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewDeleteNoProvState(t *testing.T) { + resp := initialResponse(http.MethodDelete, http.NoBody) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewPutNoProvState(t *testing.T) { + // missing provisioning state on initial response + // NOTE: ARM RPC forbids this but we allow it for back-compat + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewFinalGetLocation(t *testing.T) { + const ( + jsonBody = `{ "properties": { "provisioningState": "Started" } }` + locURL = "https://foo.bar.baz/location" + ) + resp := initialResponse(http.MethodPost, strings.NewReader(jsonBody)) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + resp.Header.Set(pollers.HeaderLocation, locURL) + poller, err := New(resp, "location", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != locURL { + t.Fatalf("unexpected final get URL %s", u) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewFinalGetOrigin(t *testing.T) { + const ( + jsonBody = `{ "properties": { "provisioningState": "Started" } }` + locURL = "https://foo.bar.baz/location" + ) + resp := initialResponse(http.MethodPost, strings.NewReader(jsonBody)) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + resp.Header.Set(pollers.HeaderLocation, locURL) + poller, err := New(resp, "original-uri", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewPutNoProvStateOnUpdate(t *testing.T) { + // missing provisioning state on initial response + // NOTE: ARM RPC forbids this but we allow it for back-compat + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(pollers.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(strings.NewReader("{}"))); err == nil { + t.Fatal("unexpected nil error") + } +} diff --git a/sdk/armcore/internal/pollers/body/body.go b/sdk/armcore/internal/pollers/body/body.go new file mode 100644 index 000000000000..55f83739e8b9 --- /dev/null +++ b/sdk/armcore/internal/pollers/body/body.go @@ -0,0 +1,105 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "errors" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// Applicable returns true if the LRO is using no headers, just provisioning state. +// This is only applicable to PATCH and PUT methods and assumes no polling headers. +func Applicable(resp *azcore.Response) bool { + // we can't check for absense of headers due to some misbehaving services + // like redis that return a Location header but don't actually use that protocol + return resp.Request.Method == http.MethodPatch || resp.Request.Method == http.MethodPut +} + +// Poller is an LRO poller that uses the Body pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *azcore.Response, pollerID string) (*Poller, error) { + azcore.Log().Write(azcore.LogLongRunningOperation, "Using Body poller.") + p := &Poller{ + Type: pollers.MakeID(pollerID, "body"), + PollURL: resp.Request.URL.String(), + } + // default initial state to InProgress. depending on the HTTP + // status code and provisioning state, we might change the value. + curState := pollers.StatusInProgress + provState, err := pollers.GetProvisioningState(resp) + if err != nil && !errors.Is(err, pollers.ErrNoBody) { + return nil, err + } + if resp.StatusCode == http.StatusCreated && provState != "" { + // absense of provisioning state is ok for a 201, means the operation is in progress + curState = provState + } else if resp.StatusCode == http.StatusOK { + if provState != "" { + curState = provState + } else if provState == "" { + // for a 200, absense of provisioning state indicates success + curState = pollers.StatusSucceeded + } + } else if resp.StatusCode == http.StatusNoContent { + curState = pollers.StatusSucceeded + } + p.CurState = curState + return p, nil +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.PollURL +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *azcore.Response) error { + if resp.StatusCode == http.StatusNoContent { + p.CurState = pollers.StatusSucceeded + return nil + } + state, err := pollers.GetProvisioningState(resp) + if errors.Is(err, pollers.ErrNoBody) { + // a missing response body in non-204 case is an error + return err + } else if state == "" { + // a response body without provisioning state is considered terminal success + state = pollers.StatusSucceeded + } else if err != nil { + return err + } + p.CurState = state + return nil +} + +// FinalGetURL returns the empty string as no final GET is required for this poller type. +func (*Poller) FinalGetURL() string { + return "" +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/armcore/internal/pollers/body/body_test.go b/sdk/armcore/internal/pollers/body/body_test.go new file mode 100644 index 000000000000..9cbd5821b028 --- /dev/null +++ b/sdk/armcore/internal/pollers/body/body_test.go @@ -0,0 +1,213 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "errors" + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, resp io.Reader) *azcore.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + Request: req, + }, + } +} + +func pollingResponse(status int, resp io.Reader) *azcore.Response { + return &azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + StatusCode: status, + }, + } +} + +func TestApplicable(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + Request: &http.Request{ + Method: http.MethodDelete, + }, + }, + } + if Applicable(&resp) { + t.Fatal("method DELETE should not be applicable") + } + resp.Request.Method = http.MethodPatch + if !Applicable(&resp) { + t.Fatal("method PATCH should be applicable") + } + resp.Request.Method = http.MethodPut + if !Applicable(&resp) { + t.Fatal("method PUT should be applicable") + } +} + +func TestNew(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusCreated + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{ "properties": { "provisioningState": "InProgress" } }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateNoProvStateFail(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusOK, http.NoBody)) + if err == nil { + t.Fatal("unexpected nil error") + } + if !errors.Is(err, pollers.ErrNoBody) { + t.Fatalf("unexpected error type %T", err) + } +} + +func TestUpdateNoProvStateSuccess(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{}`))) + if err != nil { + t.Fatal(err) + } +} + +func TestUpdateNoProvState204(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusNoContent, http.NoBody)) + if err != nil { + t.Fatal(err) + } +} + +func TestNewNoInitialProvStateOK(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("poller not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewNoInitialProvStateNC(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusNoContent + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("poller not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go new file mode 100644 index 000000000000..567dd0ad0c8e --- /dev/null +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -0,0 +1,97 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// Applicable returns true if the LRO is using Location. +func Applicable(resp *azcore.Response) bool { + return resp.StatusCode == http.StatusAccepted && resp.Header.Get(pollers.HeaderLocation) != "" +} + +// Poller is an LRO poller that uses the Location pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *azcore.Response, pollerID string) (*Poller, error) { + azcore.Log().Write(azcore.LogLongRunningOperation, "Using Location poller.") + locURL := resp.Header.Get(pollers.HeaderLocation) + if locURL == "" { + return nil, errors.New("response is missing Location header") + } + if !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid polling URL %s", locURL) + } + p := &Poller{ + Type: pollers.MakeID(pollerID, "loc"), + PollURL: locURL, + CurState: pollers.StatusInProgress, + } + return p, nil +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.PollURL +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *azcore.Response) error { + // location polling can return an updated polling URL + if h := resp.Header.Get(pollers.HeaderLocation); h != "" { + p.PollURL = h + } + if resp.HasStatusCode(http.StatusOK, http.StatusCreated) { + // if a 200/201 returns a provisioning state, use that instead + state, err := pollers.GetProvisioningState(resp) + if err != nil && !errors.Is(err, pollers.ErrNoBody) { + return err + } + if state != "" { + p.CurState = state + } else { + // a 200/201 with no provisioning state indicates success + p.CurState = pollers.StatusSucceeded + } + } else if resp.StatusCode == http.StatusNoContent { + p.CurState = pollers.StatusSucceeded + } else if resp.StatusCode > 399 && resp.StatusCode < 500 { + p.CurState = pollers.StatusFailed + } + // a 202 falls through, means the LRO is still in progress and we don't check for provisioning state + return nil +} + +// FinalGetURL returns the empty string as no final GET is required for this poller type. +func (p *Poller) FinalGetURL() string { + return "" +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/armcore/internal/pollers/loc/loc_test.go b/sdk/armcore/internal/pollers/loc/loc_test.go new file mode 100644 index 000000000000..ce14b9393bf3 --- /dev/null +++ b/sdk/armcore/internal/pollers/loc/loc_test.go @@ -0,0 +1,140 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + fakePollingURL1 = "https://foo.bar.baz/status" + fakePollingURL2 = "https://foo.bar.baz/updated" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string) *azcore.Response { + return &azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + }, + } +} + +func pollingResponse(statusCode int, body io.Reader) *azcore.Response { + return &azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(body), + Header: http.Header{}, + StatusCode: statusCode, + }, + } +} + +func TestApplicable(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + }, + } + if Applicable(&resp) { + t.Fatal("missing Location should not be applicable") + } + resp.Response.Header.Set(pollers.HeaderLocation, fakePollingURL1) + if !Applicable(&resp) { + t.Fatal("having Location should be applicable") + } +} + +func TestNew(t *testing.T) { + resp := initialResponse(http.MethodPut) + resp.Header.Set(pollers.HeaderLocation, fakePollingURL1) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL1 { + t.Fatalf("unexpected polling URL %s", u) + } + pr := pollingResponse(http.StatusAccepted, http.NoBody) + pr.Header.Set(pollers.HeaderLocation, fakePollingURL2) + if err := poller.Update(pr); err != nil { + t.Fatal(err) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusNoContent, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(http.StatusConflict, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Failed" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateWithProvState(t *testing.T) { + resp := initialResponse(http.MethodPut) + resp.Header.Set(pollers.HeaderLocation, fakePollingURL1) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL1 { + t.Fatalf("unexpected polling URL %s", u) + } + pr := pollingResponse(http.StatusAccepted, http.NoBody) + pr.Header.Set(pollers.HeaderLocation, fakePollingURL2) + if err := poller.Update(pr); err != nil { + t.Fatal(err) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{ "properties": { "provisioningState": "Updating" } }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Updating" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(http.StatusOK, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go new file mode 100644 index 000000000000..a030e081ae1f --- /dev/null +++ b/sdk/armcore/internal/pollers/pollers.go @@ -0,0 +1,148 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/url" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + HeaderAzureAsync = "Azure-AsyncOperation" + HeaderLocation = "Location" +) + +const ( + StatusSucceeded = "Succeeded" + StatusCanceled = "Canceled" + StatusFailed = "Failed" + StatusInProgress = "InProgress" +) + +// reads the response body into a raw JSON object. +// returns ErrNoBody if there was no content. +func getJSON(resp *azcore.Response) (map[string]interface{}, error) { + body, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + return nil, err + } + if len(body) == 0 { + return nil, ErrNoBody + } + // put the body back so it's available to others + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + // unmarshall the body to get the value + var jsonBody map[string]interface{} + if err = json.Unmarshal(body, &jsonBody); err != nil { + return nil, err + } + return jsonBody, nil +} + +// provisioningState returns the provisioning state from the response or the empty string. +func provisioningState(jsonBody map[string]interface{}) string { + jsonProps, ok := jsonBody["properties"] + if !ok { + return "" + } + props, ok := jsonProps.(map[string]interface{}) + if !ok { + return "" + } + rawPs, ok := props["provisioningState"] + if !ok { + return "" + } + ps, ok := rawPs.(string) + if !ok { + return "" + } + return ps +} + +// status returns the status from the response or the empty string. +func status(jsonBody map[string]interface{}) string { + rawStatus, ok := jsonBody["status"] + if !ok { + return "" + } + status, ok := rawStatus.(string) + if !ok { + return "" + } + return status +} + +// IsTerminalState returns true if the LRO's state is terminal. +func IsTerminalState(s string) bool { + return strings.EqualFold(s, StatusSucceeded) || strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) +} + +// Failed returns true if the LRO's state is terminal failure. +func Failed(s string) bool { + return strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) +} + +// GetStatus returns the LRO's status from the response body. +// Typically used for Azure-AsyncOperation flows. +// If there is no status in the response body the empty string is returned. +func GetStatus(resp *azcore.Response) (string, error) { + jsonBody, err := getJSON(resp) + if err != nil { + return "", err + } + return status(jsonBody), nil +} + +// GetProvisioningState returns the LRO's state from the response body. +// If there is no state in the response body the empty string is returned. +func GetProvisioningState(resp *azcore.Response) (string, error) { + jsonBody, err := getJSON(resp) + if err != nil { + return "", err + } + return provisioningState(jsonBody), nil +} + +// IsValidURL verifies that the URL is valid and absolute. +func IsValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() +} + +const idSeparator = ";" + +// MakeID returns the poller ID from the provided values. +func MakeID(pollerID string, kind string) string { + return fmt.Sprintf("%s%s%s", pollerID, idSeparator, kind) +} + +// DecodeID decodes the poller ID, returning [pollerID, kind] or an error. +func DecodeID(tk string) (string, string, error) { + raw := strings.Split(tk, idSeparator) + // strings.Split will include any/all whitespace strings, we want to omit those + parts := []string{} + for _, r := range raw { + if s := strings.TrimSpace(r); s != "" { + parts = append(parts, s) + } + } + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid token %s", tk) + } + return parts[0], parts[1], nil +} + +// ErrNoBody is returned if the response didn't contain a body. +var ErrNoBody = errors.New("the response did not contain a body") diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go new file mode 100644 index 000000000000..0523387daba9 --- /dev/null +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -0,0 +1,184 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +func TestIsTerminalState(t *testing.T) { + if IsTerminalState("Updating") { + t.Fatal("Updating is not a terminal state") + } + if !IsTerminalState("Succeeded") { + t.Fatal("Succeeded is a terminal state") + } + if !IsTerminalState("failed") { + t.Fatal("failed is a terminal state") + } + if !IsTerminalState("canceled") { + t.Fatal("canceled is a terminal state") + } +} + +func TestGetStatusSuccess(t *testing.T) { + const jsonBody = `{ "status": "InProgress" }` + resp := azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(strings.NewReader(jsonBody)), + }, + } + status, err := GetStatus(&resp) + if err != nil { + t.Fatal(err) + } + if status != "InProgress" { + t.Fatalf("unexpected status %s", status) + } +} + +func TestGetNoBody(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Body: http.NoBody, + }, + } + status, err := GetStatus(&resp) + if !errors.Is(err, ErrNoBody) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatal("expected empty status") + } + status, err = GetProvisioningState(&resp) + if !errors.Is(err, ErrNoBody) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatal("expected empty status") + } +} + +func TestGetStatusError(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("{}")), + }, + } + status, err := GetStatus(&resp) + if err != nil { + t.Fatal(err) + } + if status != "" { + t.Fatalf("expected empty status, got %s", status) + } +} + +func TestGetProvisioningState(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Canceled" } }` + resp := azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(strings.NewReader(jsonBody)), + }, + } + state, err := GetProvisioningState(&resp) + if err != nil { + t.Fatal(err) + } + if state != "Canceled" { + t.Fatalf("unexpected status %s", state) + } +} + +func TestGetProvisioningStateError(t *testing.T) { + resp := azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("{}")), + }, + } + state, err := GetProvisioningState(&resp) + if err != nil { + t.Fatal(err) + } + if state != "" { + t.Fatalf("expected empty provisioning state, got %s", state) + } +} + +func TestMakeID(t *testing.T) { + const ( + pollerID = "pollerID" + kind = "kind" + ) + id := MakeID(pollerID, kind) + parts := strings.Split(id, idSeparator) + if l := len(parts); l != 2 { + t.Fatalf("unexpected length %d", l) + } + if p := parts[0]; p != pollerID { + t.Fatalf("unexpected poller ID %s", p) + } + if p := parts[1]; p != kind { + t.Fatalf("unexpected poller kind %s", p) + } +} + +func TestDecodeID(t *testing.T) { + id, kind, err := DecodeID("") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID("invalid_token") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID("invalid_token;") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID(" ;invalid_token") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID("invalid;token;too") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err = DecodeID("pollerID;kind") + if err != nil { + t.Fatal(err) + } + if id != "pollerID" { + t.Fatalf("unexpected ID %s", id) + } + if kind != "kind" { + t.Fatalf("unexpected kin %s", kind) + } +} + +func TestIsValidURL(t *testing.T) { + if IsValidURL("/foo") { + t.Fatal("unexpected valid URL") + } + if !IsValidURL("https://foo.bar/baz") { + t.Fatal("expected valid URL") + } +} + +func TestFailed(t *testing.T) { + if Failed("Succeeded") || Failed("Updating") { + t.Fatal("unexpected failure") + } + if !Failed("failed") { + t.Fatal("expected failure") + } +} diff --git a/sdk/armcore/poller.go b/sdk/armcore/poller.go index bbffe09c5c69..6927406b9f4d 100644 --- a/sdk/armcore/poller.go +++ b/sdk/armcore/poller.go @@ -6,212 +6,227 @@ package armcore import ( - "bytes" "context" "encoding/json" "errors" "fmt" - "io" "io/ioutil" "net/http" - "net/url" - "reflect" - "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/loc" "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) -const ( - headerAsyncOperation = "Azure-AsyncOperation" - headerLocation = "Location" -) - -const ( - operationInProgress string = "InProgress" - operationCanceled string = "Canceled" - operationFailed string = "Failed" - operationSucceeded string = "Succeeded" -) +// ErrorUnmarshaller is the func to invoke when the endpoint returns an error response that requires unmarshalling. +type ErrorUnmarshaller func(*azcore.Response) error -var pollingCodes = [...]int{http.StatusNoContent, http.StatusAccepted, http.StatusCreated, http.StatusOK} - -// NewPoller creates a polling tracker based on the verb of the original request and returns -// the polling tracker implementation for the method verb or an error. +// NewLROPoller creates an LROPoller based on the provided initial response. +// pollerID - a unique identifier for an LRO. it's usually the client.Method string. // NOTE: this is only meant for internal use in generated code. -func NewPoller(pollerType string, finalState string, resp *azcore.Response, errorHandler methodErrorHandler) (Poller, error) { - var pt pollingTracker - switch strings.ToUpper(resp.Request.Method) { - case http.MethodDelete: - pt = &pollingTrackerDelete{pollingTrackerBase: newPollingTrackerBase(pollerType, finalState, resp, errorHandler)} - case http.MethodPatch: - pt = &pollingTrackerPatch{pollingTrackerBase: newPollingTrackerBase(pollerType, finalState, resp, errorHandler)} - case http.MethodPost: - pt = &pollingTrackerPost{pollingTrackerBase: newPollingTrackerBase(pollerType, finalState, resp, errorHandler)} - case http.MethodPut: - pt = &pollingTrackerPut{pollingTrackerBase: newPollingTrackerBase(pollerType, finalState, resp, errorHandler)} - default: - return nil, fmt.Errorf("unsupported HTTP method %s", resp.Request.Method) - } - if err := pt.initializeState(); err != nil { - return nil, err +func NewLROPoller(pollerID string, finalState string, resp *azcore.Response, pl azcore.Pipeline, eu ErrorUnmarshaller) (*LROPoller, error) { + // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). + // ideally the codegen should return an error if the initial response failed and not even create a poller. + if !lroStatusCodeValid(resp) { + return nil, errors.New("the LRO failed or was cancelled") + } + // determine the polling method + var lro lroPoller + var err error + if async.Applicable(resp) { + lro, err = async.New(resp, finalState, pollerID) + } else if loc.Applicable(resp) { + lro, err = loc.New(resp, pollerID) + } else if body.Applicable(resp) { + // must test body poller last as it's a subset of the other pollers. + // TODO: this is ambiguous for PATCH/PUT if it returns a 200 with no polling headers (sync completion) + lro, err = body.New(resp, pollerID) + } else if m := resp.Request.Method; resp.StatusCode == http.StatusAccepted && (m == http.MethodDelete || m == http.MethodPost) { + // if we get here it means we have a 202 with no polling headers. + // for DELETE and POST this is a hard error per ARM RPC spec. + return nil, errors.New("response is missing polling URL") + } else { + lro = &nopPoller{} } - // this initializes the polling header values, we do this during creation in case the - // initial response send us invalid values; this way the API call will return a non-nil - // error (not doing this means the error shows up in Future.Done) - if err := pt.updatePollingMethod(); err != nil { + if err != nil { return nil, err } - return &poller{pt: pt}, nil + return &LROPoller{lro: lro, Pipeline: pl, eu: eu, resp: resp}, nil } -// NewPollerFromResumeToken creates a polling tracker from a resume token string. +// NewLROPollerFromResumeToken creates an LROPoller from a resume token string. +// pollerID - a unique identifier for an LRO. it's usually the client.Method string. // NOTE: this is only meant for internal use in generated code. -func NewPollerFromResumeToken(pollerType string, token string, errorHandler methodErrorHandler) (Poller, error) { - // unmarshal into JSON object to determine the tracker type +func NewLROPollerFromResumeToken(pollerID string, token string, pl azcore.Pipeline, eu ErrorUnmarshaller) (*LROPoller, error) { + // unmarshal into JSON object to determine the poller type obj := map[string]interface{}{} err := json.Unmarshal([]byte(token), &obj) if err != nil { return nil, err } - if obj["pollerType"] != pollerType { - return nil, fmt.Errorf("Cannot resume from this poller type. Expected: %s, Received: %s", pollerType, obj["pollerType"]) + t, ok := obj["type"] + if !ok { + return nil, errors.New("missing type field") } - if obj["method"] == nil { - return nil, fmt.Errorf("Token is missing 'method' property") + tt, ok := t.(string) + if !ok { + return nil, fmt.Errorf("invalid type format %T", t) } - var pt pollingTracker - switch method := strings.ToUpper(obj["method"].(string)); method { - case http.MethodDelete: - pt = &pollingTrackerDelete{pollingTrackerBase: pollingTrackerBase{errorHandler: errorHandler}} - case http.MethodPatch: - pt = &pollingTrackerPatch{pollingTrackerBase: pollingTrackerBase{errorHandler: errorHandler}} - case http.MethodPost: - pt = &pollingTrackerPost{pollingTrackerBase: pollingTrackerBase{errorHandler: errorHandler}} - case http.MethodPut: - pt = &pollingTrackerPut{pollingTrackerBase: pollingTrackerBase{errorHandler: errorHandler}} + ttID, ttKind, err := pollers.DecodeID(tt) + if err != nil { + return nil, err + } + // ensure poller types match + if ttID != pollerID { + return nil, fmt.Errorf("cannot resume from this poller token. expected %s, received %s", pollerID, ttID) + } + // now rehydrate the poller based on the encoded poller type + var lro lroPoller + switch ttKind { + case "async": + azcore.Log().Write(azcore.LogLongRunningOperation, "Resuming Azure-AsyncOperation poller.") + lro = &async.Poller{} + case "loc": + azcore.Log().Write(azcore.LogLongRunningOperation, "Resuming Location poller.") + lro = &loc.Poller{} + case "body": + azcore.Log().Write(azcore.LogLongRunningOperation, "Resuming Body poller.") + lro = &body.Poller{} default: - return nil, fmt.Errorf("unsupported method '%s'", method) + return nil, fmt.Errorf("unhandled poller type %s", ttKind) } - // now unmarshal into the tracker - err = json.Unmarshal([]byte(token), &pt) - if err != nil { + if err = json.Unmarshal([]byte(token), lro); err != nil { return nil, err } - return &poller{pt: pt}, nil + return &LROPoller{lro: lro, Pipeline: pl, eu: eu}, nil } -// Poller defines the methods that will be called internally in the generated code for long-running operations. +// LROPoller encapsulates state and logic for polling on long-running operations. // NOTE: this is only meant for internal use in generated code. -type Poller interface { - // Done signals if the polling operation has reached a terminal state. - Done() bool - // Poll sends a polling request to the service endpoint and returns the http.Response received from the endpoint or an error. - Poll(ctx context.Context, p azcore.Pipeline) (*http.Response, error) - // FinalResponse will perform a final GET and return the final http response for the polling operation and unmarshal the content of the payload into the respType interface that is provided. - FinalResponse(ctx context.Context, pipeline azcore.Pipeline, respType interface{}) (*http.Response, error) - // ResumeToken returns a token string that can be used to resume polling on a poller that has not yet reached a terminal state. - ResumeToken() (string, error) - // PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, then return the final http response for the polling operation and unmarshal the content of the payload into the respType interface that is provided. - PollUntilDone(ctx context.Context, frequency time.Duration, pipeline azcore.Pipeline, respType interface{}) (*http.Response, error) +type LROPoller struct { + Pipeline azcore.Pipeline + lro lroPoller + eu ErrorUnmarshaller + resp *azcore.Response + err error } -type poller struct { - pt pollingTracker -} - -func (p *poller) Done() bool { - return p.pt.hasTerminated() +// Done returns true if the LRO has reached a terminal state. +func (l *LROPoller) Done() bool { + if l.err != nil { + return true + } + return l.lro.Done() } -func (p *poller) FinalResponse(ctx context.Context, pipeline azcore.Pipeline, respType interface{}) (*http.Response, error) { - if !p.pt.hasTerminated() { - return nil, errors.New("cannot return a final response from a poller in a non-terminal state") - } - // if respType is nil, this indicates that the request was made from an HTTPPoller - if respType == nil { - return p.pt.latestResponse().Response, nil - } - if p.pt.pollerMethodVerb() == http.MethodPut || p.pt.pollerMethodVerb() == http.MethodPatch { - res, err := p.handleResponse(p.pt.latestResponse(), respType) - if err != nil { - return nil, err - } - if res != nil && !reflect.Indirect(reflect.ValueOf(respType)).IsZero() { - return res, nil +// Poll sends a polling request to the polling endpoint and returns the response or error. +func (l *LROPoller) Poll(ctx context.Context) (*http.Response, error) { + if l.Done() { + // the LRO has reached a terminal state, don't poll again + if l.resp != nil { + return l.resp.Response, nil } + return nil, l.err } - azcore.Log().Writef(azcore.LogLongRunningOperation, "performing final GET for %s", p.pt.pollerType()) - // checking if there was a FinalStateVia configuration to re-route the final GET - // request to the value specified in the FinalStateVia property on the poller - err := p.pt.setFinalState() + req, err := azcore.NewRequest(ctx, http.MethodGet, l.lro.URL()) if err != nil { return nil, err } - if p.pt.finalGetURL() == "" { - // we can end up in this situation if the async operation returns a 200 - // with no polling URLs. in that case return the response which should - // contain the JSON payload (only do this for successful terminal cases). - if lr := p.pt.latestResponse(); lr != nil && p.pt.hasSucceeded() { - result, err := p.handleResponse(lr, respType) - if err != nil { - return nil, err - } - return result, nil - } - return nil, errors.New("missing URL for retrieving result") - } - req, err := azcore.NewRequest(ctx, http.MethodGet, p.pt.finalGetURL()) + resp, err := l.Pipeline.Do(req) if err != nil { + // don't update the poller for failed requests return nil, err } - resp, err := pipeline.Do(req) - if err != nil { + defer resp.Body.Close() + if !lroStatusCodeValid(resp) { + // the LRO failed. unmarshall the error and update state + l.err = l.eu(resp) + l.resp = nil + return nil, l.err + } + if err = l.lro.Update(resp); err != nil { return nil, err } - return p.handleResponse(resp, respType) + l.resp = resp + azcore.Log().Writef(azcore.LogLongRunningOperation, "Status %s", l.lro.Status()) + if pollers.Failed(l.lro.Status()) { + l.err = l.eu(resp) + l.resp = nil + return nil, l.err + } + return l.resp.Response, nil } -func (p *poller) ResumeToken() (string, error) { - if p.pt.hasTerminated() { +// ResumeToken returns a token string that can be used to resume a poller that has not yet reached a terminal state. +func (l *LROPoller) ResumeToken() (string, error) { + if l.Done() { return "", errors.New("cannot create a ResumeToken from a poller in a terminal state") } - js, err := json.Marshal(p.pt) + b, err := json.Marshal(l.lro) if err != nil { return "", err } - return string(js), nil + return string(b), nil } -func (p *poller) handleResponse(resp *azcore.Response, respType interface{}) (*http.Response, error) { - if resp.HasStatusCode(http.StatusNoContent) { - return resp.Response, nil +// FinalResponse will perform a final GET request and return the final HTTP response for the polling +// operation and unmarshall the content of the payload into the respType interface that is provided. +func (l *LROPoller) FinalResponse(ctx context.Context, respType interface{}) (*http.Response, error) { + if !l.Done() { + return nil, errors.New("cannot return a final response from a poller in a non-terminal state") } - if !resp.HasStatusCode(pollingCodes[:]...) { - return nil, p.pt.handleError(resp) + // if there's nothing to unmarshall into or no response body just return the final response + if respType == nil { + return l.resp.Response, nil + } else if l.resp.StatusCode == http.StatusNoContent || l.resp.ContentLength == 0 { + azcore.Log().Write(azcore.LogLongRunningOperation, "final response specifies a response type but no payload was received") + return l.resp.Response, nil + } + if u := l.lro.FinalGetURL(); u != "" { + azcore.Log().Write(azcore.LogLongRunningOperation, "Performing final GET.") + req, err := azcore.NewRequest(ctx, http.MethodGet, u) + if err != nil { + return nil, err + } + resp, err := l.Pipeline.Do(req) + if err != nil { + return nil, err + } + if !lroStatusCodeValid(resp) { + return nil, l.eu(resp) + } + l.resp = resp } - return resp.Response, resp.UnmarshalAsJSON(respType) -} - -func (p *poller) Poll(ctx context.Context, pipeline azcore.Pipeline) (*http.Response, error) { - if p.pollForStatus(ctx, pipeline, p.pt) { - return p.pt.latestResponse().Response, p.pt.pollingError() + body, err := ioutil.ReadAll(l.resp.Body) + l.resp.Body.Close() + if err != nil { + return nil, err + } + if err = json.Unmarshal(body, respType); err != nil { + return nil, err } - return nil, p.pt.pollingError() + return l.resp.Response, nil } -func (p *poller) PollUntilDone(ctx context.Context, frequency time.Duration, pipeline azcore.Pipeline, respType interface{}) (*http.Response, error) { +// PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, +// then return the final HTTP response for the polling operation and unmarshal the content of the payload +// into the respType interface that is provided. +// freq - the time to wait between polling intervals if the endpoint doesn't send a Retry-After header. +// A good starting value is 30 seconds. Note that some resources might benefit from a different value. +func (l *LROPoller) PollUntilDone(ctx context.Context, freq time.Duration, respType interface{}) (*http.Response, error) { + start := time.Now() logPollUntilDoneExit := func(v interface{}) { - azcore.Log().Writef(azcore.LogLongRunningOperation, "END PollUntilDone() for %s: %v", p.pt.pollerType(), v) + azcore.Log().Writef(azcore.LogLongRunningOperation, "END PollUntilDone() for %T: %v, total time: %s", l.lro, v, time.Now().Sub(start)) } - azcore.Log().Writef(azcore.LogLongRunningOperation, "BEGIN PollUntilDone() for %s", p.pt.pollerType()) - // p.resp should only be nil when calling PollUntilDone from a poller that was instantiated from a resume token string - if resp := p.pt.latestResponse(); resp != nil { + azcore.Log().Writef(azcore.LogLongRunningOperation, "BEGIN PollUntilDone() for %T", l.lro) + if l.resp != nil { // initial check for a retry-after header existing on the initial response - if retryAfter := azcore.RetryAfter(resp.Response); retryAfter > 0 { + if retryAfter := azcore.RetryAfter(l.resp.Response); retryAfter > 0 { azcore.Log().Writef(azcore.LogLongRunningOperation, "initial Retry-After delay for %s", retryAfter.String()) - err := delay(ctx, retryAfter) - if err != nil { + if err := delay(ctx, retryAfter); err != nil { logPollUntilDoneExit(err) return nil, err } @@ -219,745 +234,79 @@ func (p *poller) PollUntilDone(ctx context.Context, frequency time.Duration, pip } // begin polling the endpoint until a terminal state is reached for { - resp, err := p.Poll(ctx, pipeline) + resp, err := l.Poll(ctx) if err != nil { logPollUntilDoneExit(err) return nil, err } - if p.pt.hasTerminated() { - break + if l.Done() { + logPollUntilDoneExit(l.lro.Status()) + return l.FinalResponse(ctx, respType) } - d := frequency + d := freq if retryAfter := azcore.RetryAfter(resp); retryAfter > 0 { azcore.Log().Writef(azcore.LogLongRunningOperation, "Retry-After delay for %s", retryAfter.String()) d = retryAfter } else { azcore.Log().Writef(azcore.LogLongRunningOperation, "delay for %s", d.String()) } - err = delay(ctx, d) - if err != nil { + if err = delay(ctx, d); err != nil { logPollUntilDoneExit(err) return nil, err } } - logPollUntilDoneExit(p.pt.pollingStatus()) - return p.FinalResponse(ctx, pipeline, respType) -} - -func delay(ctx context.Context, delay time.Duration) error { - select { - case <-time.After(delay): - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -// pollForStatus queries the service to see if the operation has completed. -func (p *poller) pollForStatus(ctx context.Context, pl azcore.Pipeline, pt pollingTracker) bool { - if pt.hasTerminated() { - return true - } - if err := pt.pollForStatus(ctx, pl); err != nil { - return true - } - if err := pt.checkForErrors(); err != nil { - return true - } - if err := pt.updatePollingState(pt.provisioningStateApplicable()); err != nil { - return true - } - if err := pt.initPollingMethod(); err != nil { - return true - } - if err := pt.updatePollingMethod(); err != nil { - return true - } - // VERB status method URL - azcore.Log().Writef(azcore.LogLongRunningOperation, "%s %s %s %s", - strings.ToUpper(pt.pollerMethodVerb()), pt.pollingStatus(), pt.pollingMethod(), pt.pollingURL()) - return pt.hasTerminated() } -type pollingTracker interface { - // these methods can differ per tracker - - // checks the response headers and status code to determine the polling mechanism - updatePollingMethod() error - - // checks the response for tracker-specific error conditions - checkForErrors() error +var _ azcore.Poller = (*LROPoller)(nil) - // returns true if provisioning state should be checked - provisioningStateApplicable() bool - - // methods common to all trackers - - // initializes a tracker's polling URL and method, called for each iteration. - // these values can be overridden by each polling tracker as required. - initPollingMethod() error - - // initializes the tracker's internal state, call this when the tracker is created - initializeState() error - - // makes an HTTP request to check the status of the LRO - pollForStatus(ctx context.Context, client azcore.Pipeline) error - - // updates internal tracker state, call this after each call to pollForStatus - updatePollingState(provStateApl bool) error - - // returns the error response from the service, can be nil - pollingError() error - - // returns the polling method being used - pollingMethod() pollingMethodType - - // returns the state of the LRO as returned from the service - pollingStatus() string - - // returns the URL used for polling status - pollingURL() string - - // returns the client poller type - pollerType() string - - // returns the URL used for the final GET to retrieve the resource - finalGetURL() string - - // returns true if the LRO is in a terminal state - hasTerminated() bool - - // returns true if the LRO is in a failed terminal state - hasFailed() bool - - // returns true if the LRO is in a successful terminal state - hasSucceeded() bool - - // returns the cached HTTP response after a call to pollForStatus(), can be nil - latestResponse() *azcore.Response - - // converts an *azcore.Response to an error - handleError(resp *azcore.Response) error - - // sets the FinalGetURI to the value pointed to in FinalStateVia - setFinalState() error - - // returns the verb used with the initial request - pollerMethodVerb() string -} - -type methodErrorHandler func(resp *azcore.Response) error - -type pollingTrackerBase struct { - // resp is the last response, either from the submission of the LRO or from polling - resp *azcore.Response - - // PollerType is the name of the type of poller that is created - PollerType string `json:"pollerType"` - - // errorHandler is the method to invoke to unmarshall an error response - errorHandler methodErrorHandler - - // method is the HTTP verb, this is needed for deserialization - Method string `json:"method"` - - // rawBody is the raw JSON response body - rawBody map[string]interface{} - - // denotes if polling is using async-operation or location header - Pm pollingMethodType `json:"pollingMethod"` - - // the URL to poll for status - URI string `json:"pollingURI"` - - // the state of the LRO as returned from the service - State string `json:"lroState"` - - // the URL to GET for the final result - FinalGetURI string `json:"resultURI"` - - // stores the name of the header that the final get should be performed on, - // can be empty which will go to default behavior - FinalStateVia string `json:"finalStateVia"` - - // the original request URL of the initial request for the polling operation - OriginalURI string `json:"originalURI"` - - // used to hold an error object returned from the service - Err error `json:"error,omitempty"` -} - -func newPollingTrackerBase(pollerType, finalState string, resp *azcore.Response, errorHandler methodErrorHandler) pollingTrackerBase { - return pollingTrackerBase{ - PollerType: pollerType, - FinalStateVia: finalState, - OriginalURI: resp.Request.URL.String(), - resp: resp, - errorHandler: errorHandler, - } +// abstracts the differences between concrete poller types +type lroPoller interface { + Done() bool + Update(resp *azcore.Response) error + FinalGetURL() string + URL() string + Status() string } -func (pt *pollingTrackerBase) initializeState() error { - // determine the initial polling state based on response body and/or HTTP status - // code. this is applicable to the initial LRO response, not polling responses! - pt.Method = pt.resp.Request.Method - if err := pt.updateRawBody(); err != nil { - pt.Err = err - return err - } - switch pt.resp.StatusCode { - case http.StatusOK: - if ps := pt.getProvisioningState(); ps != nil { - pt.State = *ps - if pt.hasFailed() { - pt.updateErrorFromResponse() - return pt.pollingError() - } - } else { - pt.State = operationSucceeded - } - case http.StatusCreated: - if ps := pt.getProvisioningState(); ps != nil { - pt.State = *ps - } else { - pt.State = operationInProgress - } - case http.StatusAccepted: - pt.State = operationInProgress - case http.StatusNoContent: - pt.State = operationSucceeded - default: - pt.State = operationFailed - pt.updateErrorFromResponse() - return pt.pollingError() - } - return pt.initPollingMethod() -} +// ==================================================================================================== -func (pt *pollingTrackerBase) getProvisioningState() *string { - if pt.rawBody != nil && pt.rawBody["properties"] != nil { - p := pt.rawBody["properties"].(map[string]interface{}) - if ps := p["provisioningState"]; ps != nil { - s := ps.(string) - return &s - } - } - return nil -} +// used if the operation synchronously completed +type nopPoller struct{} -func (pt *pollingTrackerBase) updateRawBody() error { - pt.rawBody = map[string]interface{}{} - if pt.resp.ContentLength != 0 { - defer pt.resp.Body.Close() - b, err := ioutil.ReadAll(pt.resp.Body) - if err != nil { - pt.Err = err - return pt.Err - } - // seek back to the beginning of the body or reassign the information to the body - if seeker, ok := pt.resp.Body.(io.Seeker); ok { - _, err = seeker.Seek(0, io.SeekStart) - if err != nil { - pt.Err = err - return pt.Err - } - } else { - // put the body back so it's available to other callers - pt.resp.Body = ioutil.NopCloser(bytes.NewReader(b)) - } - // observed in 204 responses over HTTP/2.0; the content length is -1 but body is empty - if len(b) == 0 { - return nil - } - if err = json.Unmarshal(b, &pt.rawBody); err != nil { - pt.Err = err - return pt.Err - } - } - return nil +func (*nopPoller) URL() string { + return "" } -func (pt *pollingTrackerBase) pollForStatus(ctx context.Context, client azcore.Pipeline) error { - req, err := azcore.NewRequest(ctx, http.MethodGet, pt.URI) - if err != nil { - pt.Err = err - return err - } - resp, err := client.Do(req) - pt.resp = resp - if err != nil { - pt.Err = err - return pt.Err - } - if pt.resp.HasStatusCode(pollingCodes[:]...) { - // reset the service error on success case - pt.Err = nil - err = pt.updateRawBody() - } else { - // check response body for error content - pt.updateErrorFromResponse() - err = pt.pollingError() - } - return err +func (*nopPoller) Done() bool { + return true } -// attempts to unmarshal a ServiceError type from the response body. -// if that fails then make a best attempt at creating something meaningful. -// NOTE: this assumes that the async operation has failed. -func (pt *pollingTrackerBase) updateErrorFromResponse() { - pt.Err = pt.errorHandler(pt.resp) +func (*nopPoller) Succeeded() bool { + return true } -func (pt *pollingTrackerBase) updatePollingState(provStateApl bool) error { - if pt.Pm == pollingAsyncOperation && pt.rawBody["status"] != nil { - pt.State = pt.rawBody["status"].(string) - } else { - if pt.resp.StatusCode == http.StatusAccepted { - pt.State = operationInProgress - } else if provStateApl { - if ps := pt.getProvisioningState(); ps != nil { - pt.State = *ps - } else { - pt.State = operationSucceeded - } - } else { - pt.Err = fmt.Errorf("the response from the async operation has an invalid status code: %d", pt.resp.StatusCode) - return pt.Err - } - } - // if the operation has failed update the error state - if pt.hasFailed() { - pt.updateErrorFromResponse() - } +func (*nopPoller) Update(*azcore.Response) error { return nil } -func (pt *pollingTrackerBase) pollingError() error { - return pt.Err -} - -func (pt *pollingTrackerBase) pollingMethod() pollingMethodType { - return pt.Pm +func (*nopPoller) FinalGetURL() string { + return "" } -func (pt *pollingTrackerBase) pollingStatus() string { - return pt.State +func (*nopPoller) Status() string { + return pollers.StatusSucceeded } -func (pt *pollingTrackerBase) pollingURL() string { - return pt.URI +// returns true if the LRO response contains a valid HTTP status code +func lroStatusCodeValid(resp *azcore.Response) bool { + return resp.HasStatusCode(http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) } -func (pt *pollingTrackerBase) pollerType() string { - return pt.PollerType -} - -func (pt *pollingTrackerBase) finalGetURL() string { - return pt.FinalGetURI -} - -func (pt *pollingTrackerBase) hasTerminated() bool { - return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) || strings.EqualFold(pt.State, operationSucceeded) -} - -func (pt *pollingTrackerBase) hasFailed() bool { - return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) -} - -func (pt *pollingTrackerBase) hasSucceeded() bool { - return strings.EqualFold(pt.State, operationSucceeded) -} - -func (pt *pollingTrackerBase) latestResponse() *azcore.Response { - return pt.resp -} - -// error checking common to all trackers -func (pt *pollingTrackerBase) baseCheckForErrors() error { - // for Azure-AsyncOperations the response body cannot be nil or empty - if pt.Pm == pollingAsyncOperation { - if pt.resp.Body == nil || pt.resp.ContentLength == 0 { - pt.Err = errors.New("for Azure-AsyncOperation response body cannot be nil") - return pt.Err - } - if pt.rawBody["status"] == nil { - pt.Err = errors.New("missing status property in Azure-AsyncOperation response body") - return pt.Err - } - } - return nil -} - -// default initialization of polling URL/method. each verb tracker will update this as required. -func (pt *pollingTrackerBase) initPollingMethod() error { - if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - return nil - } - if lh, err := getURLFromLocationHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if lh != "" { - pt.URI = lh - pt.Pm = pollingLocation - return nil - } - // it's ok if we didn't find a polling header, this will be handled elsewhere - return nil -} - -func (pt *pollingTrackerBase) handleError(resp *azcore.Response) error { - return pt.errorHandler(resp) -} - -func (pt *pollingTrackerBase) setFinalState() error { - if len(pt.FinalStateVia) == 0 { +func delay(ctx context.Context, delay time.Duration) error { + select { + case <-time.After(delay): return nil + case <-ctx.Done(): + return ctx.Err() } - if pt.FinalStateVia == "azure-async-operation" { - ao, err := getURLFromAsyncOpHeader(pt.latestResponse()) - if err != nil { - return err - } - if ao != "" { - pt.FinalGetURI = ao - } - } else if pt.FinalStateVia == "location" { - lh, err := getURLFromLocationHeader(pt.latestResponse()) - if err != nil { - return err - } - if lh != "" { - pt.FinalGetURI = lh - } - } else if pt.FinalStateVia == "original-uri" { - pt.FinalGetURI = pt.OriginalURI - } - return nil -} - -func (pt *pollingTrackerBase) pollerMethodVerb() string { - return pt.Method -} - -// DELETE - -type pollingTrackerDelete struct { - pollingTrackerBase -} - -func (pt *pollingTrackerDelete) updatePollingMethod() error { - // for 201 the Location header is required - if pt.resp.StatusCode == http.StatusCreated { - if lh, err := getURLFromLocationHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if lh == "" { - pt.Err = errors.New("missing Location header in 201 response") - return pt.Err - } else { - pt.URI = lh - } - pt.Pm = pollingLocation - pt.FinalGetURI = pt.URI - } - // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary - if pt.resp.StatusCode == http.StatusAccepted { - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - // if the Location header is invalid and we already have a polling URL - // then we don't care if the Location header URL is malformed. - if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { - pt.Err = err - return err - } else if lh != "" { - if ao == "" { - pt.URI = lh - pt.Pm = pollingLocation - } - // when both headers are returned we use the value in the Location header for the final GET - pt.FinalGetURI = lh - } - // make sure a polling URL was found - if pt.URI == "" { - pt.Err = errors.New("didn't get any suitable polling URLs in 202 response") - return pt.Err - } - } - return nil -} - -func (pt *pollingTrackerDelete) checkForErrors() error { - return pt.baseCheckForErrors() -} - -func (pt *pollingTrackerDelete) provisioningStateApplicable() bool { - return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent -} - -// PATCH - -type pollingTrackerPatch struct { - pollingTrackerBase -} - -func (pt *pollingTrackerPatch) updatePollingMethod() error { - // by default we can use the original URL for polling and final GET - if pt.URI == "" { - pt.URI = pt.resp.Request.URL.String() - } - if pt.FinalGetURI == "" { - pt.FinalGetURI = pt.resp.Request.URL.String() - } - if pt.Pm == pollingUnknown { - pt.Pm = pollingRequestURI - } - // for 201 it's permissible for no headers to be returned - if pt.resp.StatusCode == http.StatusCreated { - if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - } - // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary - // note the absence of the "final GET" mechanism for PATCH - if pt.resp.StatusCode == http.StatusAccepted { - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - if ao == "" { - if lh, err := getURLFromLocationHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if lh == "" { - pt.Err = errors.New("didn't get any suitable polling URLs in 202 response") - return pt.Err - } else { - pt.URI = lh - pt.Pm = pollingLocation - } - } - } - return nil -} - -func (pt *pollingTrackerPatch) checkForErrors() error { - return pt.baseCheckForErrors() -} - -func (pt *pollingTrackerPatch) provisioningStateApplicable() bool { - return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated -} - -// POST - -type pollingTrackerPost struct { - pollingTrackerBase -} - -func (pt *pollingTrackerPost) updatePollingMethod() error { - // 201 requires Location header - if pt.resp.StatusCode == http.StatusCreated { - if lh, err := getURLFromLocationHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if lh == "" { - pt.Err = errors.New("missing Location header in 201 response") - return pt.Err - } else { - pt.URI = lh - pt.FinalGetURI = lh - pt.Pm = pollingLocation - } - } - // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary - if pt.resp.StatusCode == http.StatusAccepted { - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - // if the Location header is invalid and we already have a polling URL - // then we don't care if the Location header URL is malformed. - if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { - pt.Err = err - return err - } else if lh != "" { - if ao == "" { - pt.URI = lh - pt.Pm = pollingLocation - } - // when both headers are returned we use the value in the Location header for the final GET - pt.FinalGetURI = lh - } - // make sure a polling URL was found - if pt.URI == "" { - pt.Err = errors.New("didn't get any suitable polling URLs in 202 response") - return pt.Err - } - } - return nil -} - -func (pt *pollingTrackerPost) checkForErrors() error { - return pt.baseCheckForErrors() -} - -func (pt *pollingTrackerPost) provisioningStateApplicable() bool { - return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent -} - -// PUT - -type pollingTrackerPut struct { - pollingTrackerBase -} - -func (pt *pollingTrackerPut) updatePollingMethod() error { - // by default we can use the original URL for polling and final GET - if pt.URI == "" { - pt.URI = pt.resp.Request.URL.String() - } - if pt.FinalGetURI == "" { - pt.FinalGetURI = pt.resp.Request.URL.String() - } - if pt.Pm == pollingUnknown { - pt.Pm = pollingRequestURI - } - // for 201 it's permissible for no headers to be returned - if pt.resp.StatusCode == http.StatusCreated { - if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - } - // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary - if pt.resp.StatusCode == http.StatusAccepted { - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } else if ao != "" { - pt.URI = ao - pt.Pm = pollingAsyncOperation - } - // if the Location header is invalid and we already have a polling URL - // then we don't care if the Location header URL is malformed. - if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { - pt.Err = err - return err - } else if lh != "" { - if ao == "" { - pt.URI = lh - pt.Pm = pollingLocation - } - } - // make sure a polling URL was found - if pt.URI == "" { - pt.Err = errors.New("didn't get any suitable polling URLs in 202 response") - return pt.Err - } - } - return nil -} - -func (pt *pollingTrackerPut) checkForErrors() error { - err := pt.baseCheckForErrors() - if err != nil { - pt.Err = err - return err - } - // if there are no LRO headers then the body cannot be empty - ao, err := getURLFromAsyncOpHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } - lh, err := getURLFromLocationHeader(pt.resp) - if err != nil { - pt.Err = err - return err - } - if ao == "" && lh == "" && len(pt.rawBody) == 0 { - pt.Err = errors.New("the response did not contain a body") - return pt.Err - } - return nil -} - -func (pt *pollingTrackerPut) provisioningStateApplicable() bool { - return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated -} - -// gets the polling URL from the Azure-AsyncOperation header. -// ensures the URL is well-formed and absolute. -func getURLFromAsyncOpHeader(resp *azcore.Response) (string, error) { - s := resp.Header.Get(headerAsyncOperation) - if s == "" { - return "", nil - } - if !isValidURL(s) { - return "", fmt.Errorf("invalid polling URL '%s'", s) - } - return s, nil -} - -// gets the polling URL from the Location header. -// ensures the URL is well-formed and absolute. -func getURLFromLocationHeader(resp *azcore.Response) (string, error) { - s := resp.Header.Get(headerLocation) - if s == "" { - return "", nil - } - if !isValidURL(s) { - return "", fmt.Errorf("invalid polling URL '%s'", s) - } - return s, nil -} - -// verify that the URL is valid and absolute -func isValidURL(s string) bool { - u, err := url.Parse(s) - return err == nil && u.IsAbs() } - -// pollingMethodType defines a type used for enumerating polling mechanisms. -type pollingMethodType string - -const ( - // pollingAsyncOperation indicates the polling method uses the Azure-AsyncOperation header. - pollingAsyncOperation pollingMethodType = "AsyncOperation" - - // pollingLocation indicates the polling method uses the Location header. - pollingLocation pollingMethodType = "Location" - - // pollingRequestURI indicates the polling method uses the original request URI. - pollingRequestURI pollingMethodType = "RequestURI" - - // pollingUnknown indicates an unknown polling method and is the default value. - pollingUnknown pollingMethodType = "" -) diff --git a/sdk/armcore/poller_test.go b/sdk/armcore/poller_test.go index fc250c84ed69..b71aa9f7ed95 100644 --- a/sdk/armcore/poller_test.go +++ b/sdk/armcore/poller_test.go @@ -7,196 +7,317 @@ package armcore import ( "context" - "fmt" + "io" + "io/ioutil" "net/http" + "strings" "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers/loc" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) const ( - mockSuccessResp = `{"field": "success"}` + provStateStarted = `{ "properties": { "provisioningState": "Started" } }` + provStateUpdating = `{ "properties": { "provisioningState": "Updating" } }` + provStateSucceeded = `{ "properties": { "provisioningState": "Succeeded" }, "field": "value" }` + provStateFailed = `{ "properties": { "provisioningState": "Failed" } }` + statusInProgress = `{ "status": "InProgress" }` + statusSucceeded = `{ "status": "Succeeded" }` + statusCanceled = `{ "status": "Canceled" }` + successResp = `{ "field": "value" }` + errorResp = `{ "error": "the operation failed" }` ) type mockType struct { Field *string `json:"field,omitempty"` } +type mockError struct { + Msg string `json:"error"` +} + +func (m mockError) Error() string { + return m.Msg +} + func getPipeline(srv *mock.Server) azcore.Pipeline { return azcore.NewPipeline( srv, - azcore.NewRetryPolicy(nil), azcore.NewLogPolicy(nil)) } func handleError(resp *azcore.Response) error { - return fmt.Errorf("error status: %d", resp.StatusCode) + var me mockError + if err := resp.UnmarshalAsJSON(&me); err != nil { + return err + } + return me } -func TestNewPollerTracker(t *testing.T) { +func initialResponse(method, u string, resp io.Reader) *azcore.Response { + req, err := http.NewRequest(method, u, nil) + if err != nil { + panic(err) + } + return &azcore.Response{ + Response: &http.Response{ + Body: ioutil.NopCloser(resp), + ContentLength: -1, + Header: http.Header{}, + Request: req, + }, + } +} + +func TestNewLROPollerAsync(t *testing.T) { srv, close := mock.NewServer() defer close() - srv.AppendResponse(mock.WithBody([]byte(mockSuccessResp))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPost, srv.URL()) + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - resp, err := p.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*async.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - pller, err := NewPoller("testPoller", "", resp, handleError) + tk, err := poller.ResumeToken() if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - pt := pller.(*poller).pt.(*pollingTrackerPost) - if pt.PollerType != "testPoller" { - t.Fatal("wrong poller type assigned") - } - if pt.resp != resp { - t.Fatal("wrong response assigned") + poller, err = NewLROPollerFromResumeToken("pollerID", tk, pl, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) } - if pt.Method != "POST" { - t.Fatal("wrong poller method") + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } } -func TestResumeTokenFail(t *testing.T) { +func TestNewLROPollerBody(t *testing.T) { srv, close := mock.NewServer() defer close() - srv.AppendResponse(mock.WithBody([]byte(`{"field": "success"}`))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPost, srv.URL()) + srv.AppendResponse(mock.WithBody([]byte(provStateUpdating)), mock.WithHeader("Retry-After", "1")) + srv.AppendResponse(mock.WithBody([]byte(provStateSucceeded))) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - resp, err := p.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*body.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - poller, err := NewPoller("testPoller", "", resp, handleError) + tk, err := poller.ResumeToken() if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - tk, err := poller.ResumeToken() - if err == nil { - t.Fatalf("Expected an error but did not receive one") + poller, err = NewLROPollerFromResumeToken("pollerID", tk, pl, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) } - if tk != "" { - t.Fatal("did not expect to receive resume token for a poller in a terminal state") + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } } -func TestPollUntilDone(t *testing.T) { +func TestNewLROPollerLoc(t *testing.T) { srv, close := mock.NewServer() defer close() srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithBody([]byte(mockSuccessResp))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPut, srv.URL()) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderLocation, srv.URL()) + resp.StatusCode = http.StatusAccepted + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) + } + if _, ok := poller.lro.(*loc.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - resp, err := p.Do(req) + tk, err := poller.ResumeToken() if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - poller, err := NewPoller("testPoller", "", resp, handleError) + poller, err = NewLROPollerFromResumeToken("pollerID", tk, pl, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } - m := &mockType{} - pollerResp, err := poller.PollUntilDone(context.Background(), 1*time.Millisecond, p, m) +} + +func TestNewLROPollerNop(t *testing.T) { + srv, close := mock.NewServer() + defer close() + resp := initialResponse(http.MethodPost, srv.URL(), strings.NewReader(successResp)) + resp.StatusCode = http.StatusOK + poller, err := NewLROPoller("pollerID", "", resp, getPipeline(srv), handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - if pollerResp == nil { - t.Fatal("Unexpected nil response") + if _, ok := poller.lro.(*nopPoller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - if pollerResp.StatusCode != http.StatusOK { - t.Fatal("Unexpected response status code") + tk, err := poller.ResumeToken() + if err == nil { + t.Fatal("unexpected nil error") } - if *m.Field != "success" { - t.Fatalf("Unexpected value for MockType.Field: %s", *m.Field) + if tk != "" { + t.Fatal("expected empty token") + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } } -func TestPutFinalResponseCheck(t *testing.T) { +func TestNewLROPollerInitialRetryAfter(t *testing.T) { srv, close := mock.NewServer() defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithBody([]byte(`{"other": "other"}`))) - srv.AppendResponse(mock.WithBody([]byte(mockSuccessResp))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPut, srv.URL()) + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderAzureAsync, srv.URL()) + resp.Header.Set("Retry-After", "1") + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - resp, err := p.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*async.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - poller, err := NewPoller("testPoller", "", resp, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) } - m := &mockType{} - pollerResp, err := poller.PollUntilDone(context.Background(), 1*time.Millisecond, p, m) +} + +func TestNewLROPollerCanceled(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusCanceled)), mock.WithStatusCode(http.StatusOK)) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - if pollerResp == nil { - t.Fatal("Unexpected nil response") + if _, ok := poller.lro.(*async.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - if pollerResp.StatusCode != http.StatusOK { - t.Fatal("Unexpected response status code") + _, err = poller.Poll(context.Background()) + if err != nil { + t.Fatal(err) } - if *m.Field != "success" { - t.Fatalf("Unexpected value for MockType.Field: %s", *m.Field) + _, err = poller.Poll(context.Background()) + if err == nil { + t.Fatal("unexpected nil error") } } -func TestNewPollerFromResumeTokenTracker(t *testing.T) { +func TestNewLROPollerFailedWithError(t *testing.T) { srv, close := mock.NewServer() defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithBody([]byte(`{"field": "success"}`))) - p := getPipeline(srv) - req, err := azcore.NewRequest(context.Background(), http.MethodPut, srv.URL()) + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(errorResp)), mock.WithStatusCode(http.StatusBadRequest)) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(pollers.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - resp, err := p.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if _, ok := poller.lro.(*async.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) } - poller, err := NewPoller("testPoller", "", resp, handleError) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err == nil { + t.Fatal(err) } - tk, err := poller.ResumeToken() + if _, ok := err.(mockError); !ok { + t.Fatalf("unexpected error type %T", err) + } +} + +func TestNewLROPollerSuccessNoContent(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(provStateUpdating))) + srv.AppendResponse(mock.WithStatusCode(http.StatusNoContent)) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - poller, err = NewPollerFromResumeToken("testPoller", tk, handleError) + if _, ok := poller.lro.(*body.Poller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) + } + tk, err := poller.ResumeToken() if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - m := &mockType{} - pollerResp, err := poller.PollUntilDone(context.Background(), 1*time.Millisecond, p, m) + poller, err = NewLROPollerFromResumeToken("pollerID", tk, pl, handleError) + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatal(err) } - if pollerResp == nil { - t.Fatal("Unexpected nil response") + if result.Field != nil { + t.Fatal("expected nil result") } - if pollerResp.StatusCode != http.StatusOK { - t.Fatal("Unexpected response status code") +} + +func TestNewLROPollerFail202NoHeaders(t *testing.T) { + srv, close := mock.NewServer() + defer close() + resp := initialResponse(http.MethodDelete, srv.URL(), http.NoBody) + resp.StatusCode = http.StatusAccepted + pl := getPipeline(srv) + poller, err := NewLROPoller("pollerID", "", resp, pl, handleError) + if err == nil { + t.Fatal("unexpected nil error") } - if *m.Field != "success" { - t.Fatalf("Unexpected value for MockType.Field: %s", *m.Field) + if poller != nil { + t.Fatal("expected nil poller") } } diff --git a/sdk/armcore/version.go b/sdk/armcore/version.go index c100fa154219..acdd3fcf5579 100644 --- a/sdk/armcore/version.go +++ b/sdk/armcore/version.go @@ -10,5 +10,5 @@ const ( UserAgent = "armcore/" + Version // Version is the semantic version (see http://semver.org) of this module. - Version = "v0.7.1" + Version = "v0.8.0" )