From bbf5386148ebae79096e2880c17ef89a15e472e9 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 25 Aug 2021 17:14:00 -0700 Subject: [PATCH] Refactor azcore (#15349) * Refactor azcore See the CHANGELOG for details. * add additional test coverage move early header check back to Pipeline.Do * add nil body check * added more test coverage removed spurious error return from NewPoller() sig. fixed bug in FinalResponse() to perform the final GET if the URL is specified but the terminal response contains no content. * fix tests that don't return a final payload * remove bad examples * move constants to make version info easier to find * moved Poller constructors out of root packages for core, moved to azcore/runtime for arm, moved to azcore/arm/runtime also moved ARM RP registration policy to runtime In order to do this, it was necessary to move all pipeline related content to its own internal package azcore/internal/pipeline. This was to avoid problems with circular dependencies. It also groups things that are logically related and cleaned up the code. --- sdk/azcore/CHANGELOG.md | 27 +- sdk/azcore/arm/connection.go | 120 +++++ sdk/azcore/arm/connection_test.go | 202 ++++++++ .../arm/internal/pollers/async/async.go | 139 ++++++ .../arm/internal/pollers/async/async_test.go | 180 +++++++ sdk/azcore/arm/internal/pollers/body/body.go | 111 +++++ .../arm/internal/pollers/body/body_test.go | 207 ++++++++ sdk/azcore/arm/internal/pollers/loc/loc.go | 104 ++++ .../arm/internal/pollers/loc/loc_test.go | 133 ++++++ sdk/azcore/arm/internal/pollers/pollers.go | 68 +++ .../arm/internal/pollers/pollers_test.go | 91 ++++ sdk/azcore/arm/runtime/policy_register_rp.go | 384 +++++++++++++++ .../arm/runtime/policy_register_rp_test.go | 372 +++++++++++++++ sdk/azcore/arm/runtime/poller.go | 81 ++++ sdk/azcore/arm/runtime/poller_test.go | 308 ++++++++++++ sdk/azcore/core.go | 105 +--- sdk/azcore/credential.go | 31 +- sdk/azcore/error.go | 70 --- sdk/azcore/errors.go | 26 + sdk/azcore/example_test.go | 41 +- sdk/azcore/internal/pipeline/pipeline.go | 93 ++++ sdk/azcore/internal/pipeline/pipeline_test.go | 103 ++++ sdk/azcore/internal/pipeline/request.go | 169 +++++++ sdk/azcore/internal/pipeline/request_test.go | 139 ++++++ sdk/azcore/internal/pollers/loc/loc.go | 80 ++++ sdk/azcore/internal/pollers/loc/loc_test.go | 136 ++++++ sdk/azcore/internal/pollers/op/op.go | 132 ++++++ sdk/azcore/internal/pollers/op/op_test.go | 249 ++++++++++ sdk/azcore/internal/pollers/poller.go | 213 +++++++++ sdk/azcore/internal/pollers/poller_test.go | 256 ++++++++++ sdk/azcore/internal/pollers/util.go | 99 ++++ sdk/azcore/internal/pollers/util_test.go | 169 +++++++ sdk/azcore/internal/shared/constants.go | 34 ++ sdk/azcore/internal/shared/shared.go | 148 ++++++ sdk/azcore/internal/shared/shared_test.go | 133 ++++++ sdk/azcore/policy/policy.go | 96 ++++ sdk/azcore/policy/policy_test.go | 54 +++ sdk/azcore/policy_anonymous_credential.go | 14 +- .../policy_anonymous_credential_test.go | 7 +- sdk/azcore/policy_http_header.go | 39 -- sdk/azcore/poller.go | 448 ------------------ sdk/azcore/request.go | 398 ---------------- sdk/azcore/runtime/errors.go | 21 + .../{ => runtime}/policy_body_download.go | 23 +- .../policy_body_download_test.go | 2 +- sdk/azcore/runtime/policy_http_header.go | 31 ++ .../{ => runtime}/policy_http_header_test.go | 11 +- sdk/azcore/{ => runtime}/policy_logging.go | 76 ++- .../{ => runtime}/policy_logging_test.go | 10 +- sdk/azcore/{ => runtime}/policy_retry.go | 105 ++-- sdk/azcore/{ => runtime}/policy_retry_test.go | 56 ++- sdk/azcore/{ => runtime}/policy_telemetry.go | 52 +- .../{ => runtime}/policy_telemetry_test.go | 74 +-- sdk/azcore/runtime/poller.go | 70 +++ sdk/azcore/{ => runtime}/poller_test.go | 414 ++++++++-------- sdk/azcore/runtime/request.go | 228 +++++++++ sdk/azcore/{ => runtime}/request_test.go | 96 ++-- sdk/azcore/{ => runtime}/response.go | 64 +-- sdk/azcore/{ => runtime}/response_test.go | 30 +- .../runtime/transport_default_http_client.go | 35 ++ sdk/azcore/{ => streaming}/progress.go | 13 +- sdk/azcore/{ => streaming}/progress_test.go | 11 +- sdk/azcore/to/to.go | 107 +++++ sdk/azcore/to/to_test.go | 192 ++++++++ sdk/azcore/transport_default_http_client.go | 22 - sdk/azcore/version.go | 15 - 66 files changed, 6020 insertions(+), 1747 deletions(-) create mode 100644 sdk/azcore/arm/connection.go create mode 100644 sdk/azcore/arm/connection_test.go create mode 100644 sdk/azcore/arm/internal/pollers/async/async.go create mode 100644 sdk/azcore/arm/internal/pollers/async/async_test.go create mode 100644 sdk/azcore/arm/internal/pollers/body/body.go create mode 100644 sdk/azcore/arm/internal/pollers/body/body_test.go create mode 100644 sdk/azcore/arm/internal/pollers/loc/loc.go create mode 100644 sdk/azcore/arm/internal/pollers/loc/loc_test.go create mode 100644 sdk/azcore/arm/internal/pollers/pollers.go create mode 100644 sdk/azcore/arm/internal/pollers/pollers_test.go create mode 100644 sdk/azcore/arm/runtime/policy_register_rp.go create mode 100644 sdk/azcore/arm/runtime/policy_register_rp_test.go create mode 100644 sdk/azcore/arm/runtime/poller.go create mode 100644 sdk/azcore/arm/runtime/poller_test.go delete mode 100644 sdk/azcore/error.go create mode 100644 sdk/azcore/errors.go create mode 100644 sdk/azcore/internal/pipeline/pipeline.go create mode 100644 sdk/azcore/internal/pipeline/pipeline_test.go create mode 100644 sdk/azcore/internal/pipeline/request.go create mode 100644 sdk/azcore/internal/pipeline/request_test.go create mode 100644 sdk/azcore/internal/pollers/loc/loc.go create mode 100644 sdk/azcore/internal/pollers/loc/loc_test.go create mode 100644 sdk/azcore/internal/pollers/op/op.go create mode 100644 sdk/azcore/internal/pollers/op/op_test.go create mode 100644 sdk/azcore/internal/pollers/poller.go create mode 100644 sdk/azcore/internal/pollers/poller_test.go create mode 100644 sdk/azcore/internal/pollers/util.go create mode 100644 sdk/azcore/internal/pollers/util_test.go create mode 100644 sdk/azcore/internal/shared/constants.go create mode 100644 sdk/azcore/internal/shared/shared.go create mode 100644 sdk/azcore/internal/shared/shared_test.go create mode 100644 sdk/azcore/policy/policy.go create mode 100644 sdk/azcore/policy/policy_test.go delete mode 100644 sdk/azcore/policy_http_header.go delete mode 100644 sdk/azcore/poller.go delete mode 100644 sdk/azcore/request.go create mode 100644 sdk/azcore/runtime/errors.go rename sdk/azcore/{ => runtime}/policy_body_download.go (80%) rename sdk/azcore/{ => runtime}/policy_body_download_test.go (99%) create mode 100644 sdk/azcore/runtime/policy_http_header.go rename sdk/azcore/{ => runtime}/policy_http_header_test.go (91%) rename sdk/azcore/{ => runtime}/policy_logging.go (60%) rename sdk/azcore/{ => runtime}/policy_logging_test.go (95%) rename sdk/azcore/{ => runtime}/policy_retry.go (60%) rename sdk/azcore/{ => runtime}/policy_retry_test.go (90%) rename sdk/azcore/{ => runtime}/policy_telemetry.go (54%) rename sdk/azcore/{ => runtime}/policy_telemetry_test.go (52%) create mode 100644 sdk/azcore/runtime/poller.go rename sdk/azcore/{ => runtime}/poller_test.go (97%) create mode 100644 sdk/azcore/runtime/request.go rename sdk/azcore/{ => runtime}/request_test.go (82%) rename sdk/azcore/{ => runtime}/response.go (76%) rename sdk/azcore/{ => runtime}/response_test.go (88%) create mode 100644 sdk/azcore/runtime/transport_default_http_client.go rename sdk/azcore/{ => streaming}/progress.go (78%) rename sdk/azcore/{ => streaming}/progress_test.go (89%) create mode 100644 sdk/azcore/to/to.go create mode 100644 sdk/azcore/to/to_test.go delete mode 100644 sdk/azcore/transport_default_http_client.go delete mode 100644 sdk/azcore/version.go diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 56fd3d94be99..edf6588870e8 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -1,13 +1,36 @@ # Release History +## v0.19.0 + +### Breaking Changes +* Split content out of `azcore` into various packages. The intent is to separate content based on its usage (common, uncommon, SDK authors). + * `azcore` has all core functionality. + * `log` contains facilities for configuring in-box logging. + * `policy` is used for configuring pipeline options and creating custom pipeline policies. + * `runtime` contains various helpers used by SDK authors and generated content. + * `streaming` has helpers for streaming IO operations. +* `NewTelemetryPolicy()` now requires module and version parameters and the `Value` option has been removed. + * As a result, the `Request.Telemetry()` method has been removed. +* The telemetry policy now includes the SDK prefix `azsdk-go-` so callers no longer need to provide it. +* The `*http.Request` in `runtime.Request` is no longer anonymously embedded. Use the `Raw()` method to access it. +* The `UserAgent` and `Version` constants have been made internal, `Module` and `Version` respectively. + +### Bug Fixes +* Fixed an issue in the retry policy where the request body could be overwritten after a rewind. + +### Other Changes +* Moved modules `armcore` and `to` content into `arm` and `to` packages respectively. + * The `Pipeline()` method on `armcore.Connection` has been replaced by `NewPipeline()` in `arm.Connection`. It takes module and version parameters used by the telemetry policy. +* Poller logic has been consolidated across ARM and core implementations. + * This required some changes to the internal interfaces for core pollers. +* The core poller types have been improved, including more logging and test coverage. + ## v0.18.1 ### Features Added * Adds an `ETag` type for comparing etags and handling etags on requests * Simplifies the `requestBodyProgess` and `responseBodyProgress` into a single `progress` object -### Breaking Changes - ### Bugs Fixed * `JoinPaths` will preserve query parameters encoded in the `root` url. diff --git a/sdk/azcore/arm/connection.go b/sdk/azcore/arm/connection.go new file mode 100644 index 000000000000..4a7c4f530ed3 --- /dev/null +++ b/sdk/azcore/arm/connection.go @@ -0,0 +1,120 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package arm + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) + +const ( + // AzureChina is the Azure Resource Manager China cloud endpoint. + AzureChina = "https://management.chinacloudapi.cn/" + // AzureGermany is the Azure Resource Manager Germany cloud endpoint. + AzureGermany = "https://management.microsoftazure.de/" + // AzureGovernment is the Azure Resource Manager US government cloud endpoint. + AzureGovernment = "https://management.usgovcloudapi.net/" + // AzurePublicCloud is the Azure Resource Manager public cloud endpoint. + AzurePublicCloud = "https://management.azure.com/" +) + +// ConnectionOptions contains configuration settings for the connection's pipeline. +// All zero-value fields will be initialized with their default values. +type ConnectionOptions struct { + // AuxiliaryTenants contains a list of additional tenants to be used to authenticate + // across multiple tenants. + AuxiliaryTenants []string + + // HTTPClient sets the transport for making HTTP requests. + HTTPClient policy.Transporter + + // Retry configures the built-in retry policy behavior. + Retry policy.RetryOptions + + // Telemetry configures the built-in telemetry policy behavior. + Telemetry policy.TelemetryOptions + + // Logging configures the built-in logging policy behavior. + Logging policy.LogOptions + + // DisableRPRegistration disables the auto-RP registration policy. + // The default value is false. + DisableRPRegistration bool + + // PerCallPolicies contains custom policies to inject into the pipeline. + // Each policy is executed once per request. + PerCallPolicies []policy.Policy + + // PerRetryPolicies contains custom policies to inject into the pipeline. + // Each policy is executed once per request, and for each retry request. + PerRetryPolicies []policy.Policy +} + +// Connection is a connection to an Azure Resource Manager endpoint. +// It contains the base ARM endpoint and a pipeline for making requests. +type Connection struct { + ep string + cred azcore.TokenCredential + opt ConnectionOptions +} + +// NewDefaultConnection creates an instance of the Connection type using the AzurePublicCloud. +// Pass nil to accept the default options; this is the same as passing a zero-value options. +func NewDefaultConnection(cred azcore.TokenCredential, options *ConnectionOptions) *Connection { + return NewConnection(AzurePublicCloud, cred, options) +} + +// NewConnection creates an instance of the Connection type with the specified endpoint. +// Use this when connecting to clouds other than the Azure public cloud (stack/sovereign clouds). +// Pass nil to accept the default options; this is the same as passing a zero-value options. +func NewConnection(endpoint string, cred azcore.TokenCredential, options *ConnectionOptions) *Connection { + if options == nil { + options = &ConnectionOptions{} + } + return &Connection{ep: endpoint, cred: cred, opt: *options} +} + +// Endpoint returns the connection's ARM endpoint. +func (con *Connection) Endpoint() string { + return con.ep +} + +// NewPipeline creates a pipeline from the connection's options. +// The telemetry policy, when enabled, will use the specified module and version info. +func (con *Connection) NewPipeline(module, version string) pipeline.Pipeline { + policies := []policy.Policy{} + if !con.opt.Telemetry.Disabled { + policies = append(policies, azruntime.NewTelemetryPolicy(module, version, &con.opt.Telemetry)) + } + if !con.opt.DisableRPRegistration { + regRPOpts := armruntime.RegistrationOptions{ + HTTPClient: con.opt.HTTPClient, + Logging: con.opt.Logging, + Retry: con.opt.Retry, + Telemetry: con.opt.Telemetry, + } + policies = append(policies, armruntime.NewRPRegistrationPolicy(con.ep, con.cred, ®RPOpts)) + } + policies = append(policies, con.opt.PerCallPolicies...) + policies = append(policies, azruntime.NewRetryPolicy(&con.opt.Retry)) + policies = append(policies, con.opt.PerRetryPolicies...) + policies = append(policies, + con.cred.NewAuthenticationPolicy( + azruntime.AuthenticationOptions{ + TokenRequest: policy.TokenRequestOptions{ + Scopes: []string{shared.EndpointToScope(con.ep)}, + }, + AuxiliaryTenants: con.opt.AuxiliaryTenants, + }, + ), + azruntime.NewLogPolicy(&con.opt.Logging)) + return azruntime.NewPipeline(con.opt.HTTPClient, policies...) +} diff --git a/sdk/azcore/arm/connection_test.go b/sdk/azcore/arm/connection_test.go new file mode 100644 index 000000000000..be0bf26453e6 --- /dev/null +++ b/sdk/azcore/arm/connection_test.go @@ -0,0 +1,202 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package arm + +import ( + "context" + "net/http" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +type mockTokenCred struct{} + +func (mockTokenCred) NewAuthenticationPolicy(azruntime.AuthenticationOptions) policy.Policy { + return pipeline.PolicyFunc(func(req *policy.Request) (*http.Response, error) { + return req.Next() + }) +} + +func (mockTokenCred) GetToken(context.Context, policy.TokenRequestOptions) (*azcore.AccessToken, error) { + return &azcore.AccessToken{ + Token: "abc123", + ExpiresOn: time.Now().Add(1 * time.Hour), + }, nil +} + +const rpUnregisteredResp = `{ + "error":{ + "code":"MissingSubscriptionRegistration", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions.", + "details":[{ + "code":"MissingSubscriptionRegistration", + "target":"Microsoft.Storage", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions." + } + ] + } +}` + +func TestNewDefaultConnection(t *testing.T) { + opt := ConnectionOptions{} + con := NewDefaultConnection(mockTokenCred{}, &opt) + if ep := con.Endpoint(); ep != AzurePublicCloud { + t.Fatalf("unexpected endpoint %s", ep) + } +} + +func TestNewConnection(t *testing.T) { + const customEndpoint = "https://contoso.com/fake/endpoint" + con := NewConnection(customEndpoint, mockTokenCred{}, nil) + if ep := con.Endpoint(); ep != customEndpoint { + t.Fatalf("unexpected endpoint %s", ep) + } +} + +func TestNewConnectionWithOptions(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse() + opt := ConnectionOptions{} + opt.HTTPClient = srv + con := NewConnection(srv.URL(), mockTokenCred{}, &opt) + if ep := con.Endpoint(); ep != srv.URL() { + t.Fatalf("unexpected endpoint %s", ep) + } + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + if ua := resp.Request.Header.Get("User-Agent"); !strings.HasPrefix(ua, "azsdk-go-armtest/v1.2.3") { + t.Fatalf("unexpected User-Agent %s", ua) + } +} + +func TestNewConnectionWithCustomTelemetry(t *testing.T) { + const myTelemetry = "something" + srv, close := mock.NewServer() + defer close() + srv.AppendResponse() + opt := ConnectionOptions{} + opt.HTTPClient = srv + opt.Telemetry.ApplicationID = myTelemetry + con := NewConnection(srv.URL(), mockTokenCred{}, &opt) + if ep := con.Endpoint(); ep != srv.URL() { + t.Fatalf("unexpected endpoint %s", ep) + } + if opt.Telemetry.ApplicationID != myTelemetry { + t.Fatalf("telemetry was modified: %s", opt.Telemetry.ApplicationID) + } + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + if ua := resp.Request.Header.Get("User-Agent"); !strings.HasPrefix(ua, myTelemetry+" "+"azsdk-go-armtest/v1.2.3") { + t.Fatalf("unexpected User-Agent %s", ua) + } +} + +func TestDisableAutoRPRegistration(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.SetResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + con := NewConnection(srv.URL(), mockTokenCred{}, &ConnectionOptions{DisableRPRegistration: true}) + if ep := con.Endpoint(); ep != srv.URL() { + t.Fatalf("unexpected endpoint %s", ep) + } + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // log only RP registration + log.SetClassifications(armruntime.LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d:", resp.StatusCode) + } + // shouldn't be any log entries + if logEntries != 0 { + t.Fatalf("expected 0 log entries, got %d", logEntries) + } +} + +// policy that tracks the number of times it was invoked +type countingPolicy struct { + count int +} + +func (p *countingPolicy) Do(req *policy.Request) (*http.Response, error) { + p.count++ + return req.Next() +} + +func TestConnectionWithCustomPolicies(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response is a failure to trigger retry + srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + perCallPolicy := countingPolicy{} + perRetryPolicy := countingPolicy{} + con := NewConnection(srv.URL(), mockTokenCred{}, &ConnectionOptions{ + DisableRPRegistration: true, + PerCallPolicies: []policy.Policy{&perCallPolicy}, + PerRetryPolicies: []policy.Policy{&perRetryPolicy}, + }) + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } + if perCallPolicy.count != 1 { + t.Fatalf("unexpected per call policy count %d", perCallPolicy.count) + } + if perRetryPolicy.count != 2 { + t.Fatalf("unexpected per retry policy count %d", perRetryPolicy.count) + } +} diff --git a/sdk/azcore/arm/internal/pollers/async/async.go b/sdk/azcore/arm/internal/pollers/async/async.go new file mode 100644 index 000000000000..3a3cd0a3ca96 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/async/async.go @@ -0,0 +1,139 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "errors" + "fmt" + "net/http" + + armpollers "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "Azure-AsyncOperation" + +const ( + finalStateAsync = "azure-async-operation" + finalStateLoc = "location" //nolint + finalStateOrig = "original-uri" +) + +// Applicable returns true if the LRO is using Azure-AsyncOperation. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderAzureAsync) != "" +} + +// Poller is an LRO poller that uses the Azure-AsyncOperation pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL from Azure-AsyncOperation header. + AsyncURL string `json:"asyncURL"` + + // The URL from Location header. + LocURL string `json:"locURL"` + + // The URL from the initial LRO request. + OrigURL string `json:"origURL"` + + // The HTTP method from the initial LRO request. + Method string `json:"method"` + + // The value of final-state-via from swagger, can be the empty string. + FinalState string `json:"finalState"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response and final-state type. +func New(resp *http.Response, finalState string, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Azure-AsyncOperation poller.") + asyncURL := resp.Header.Get(shared.HeaderAzureAsync) + if asyncURL == "" { + return nil, errors.New("response is missing Azure-AsyncOperation header") + } + if !pollers.IsValidURL(asyncURL) { + return nil, fmt.Errorf("invalid polling URL %s", asyncURL) + } + p := &Poller{ + Type: pollers.MakeID(pollerID, Kind), + AsyncURL: asyncURL, + LocURL: resp.Header.Get(shared.HeaderLocation), + OrigURL: resp.Request.URL.String(), + Method: resp.Request.Method, + FinalState: finalState, + } + // check for provisioning state + state, err := armpollers.GetProvisioningState(resp) + if errors.Is(err, shared.ErrNoBody) || state == "" { + // NOTE: the ARM RPC spec explicitly states that for async PUT the initial response MUST + // contain a provisioning state. to maintain compat with track 1 and other implementations + // we are explicitly relaxing this requirement. + /*if resp.Request.Method == http.MethodPut { + // initial response for a PUT requires a provisioning state + return nil, err + }*/ + // for DELETE/PATCH/POST, provisioning state is optional + state = pollers.StatusInProgress + } else if err != nil { + return nil, err + } + p.CurState = state + return p, nil +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *http.Response) error { + state, err := armpollers.GetStatus(resp) + if err != nil { + return err + } else if state == "" { + return errors.New("the response did not contain a status") + } + p.CurState = state + return nil +} + +// FinalGetURL returns the URL to perform a final GET for the payload, or the empty string if not required. +func (p *Poller) FinalGetURL() string { + if p.Method == http.MethodPatch || p.Method == http.MethodPut { + // for PATCH and PUT, the final GET is on the original resource URL + return p.OrigURL + } else if p.Method == http.MethodPost { + if p.FinalState == finalStateAsync { + return "" + } else if p.FinalState == finalStateOrig { + return p.OrigURL + } else if p.LocURL != "" { + // ideally FinalState would be set to "location" but it isn't always. + // must check last due to more permissive condition. + return p.LocURL + } + } + return "" +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.AsyncURL +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/azcore/arm/internal/pollers/async/async_test.go b/sdk/azcore/arm/internal/pollers/async/async_test.go new file mode 100644 index 000000000000..a87f503a6a77 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/async/async_test.go @@ -0,0 +1,180 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakePollingURL = "https://foo.bar.baz/status" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, resp io.Reader) *http.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + Request: req, + } +} + +func pollingResponse(resp io.Reader) *http.Response { + return &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + } + if Applicable(resp) { + t.Fatal("missing Azure-AsyncOperation should not be applicable") + } + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + if !Applicable(resp) { + t.Fatal("having Azure-AsyncOperation should be applicable") + } +} + +func TestNew(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(strings.NewReader(`{ "status": "InProgress" }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewDeleteNoProvState(t *testing.T) { + resp := initialResponse(http.MethodDelete, http.NoBody) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewPutNoProvState(t *testing.T) { + // missing provisioning state on initial response + // NOTE: ARM RPC forbids this but we allow it for back-compat + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewFinalGetLocation(t *testing.T) { + const ( + jsonBody = `{ "properties": { "provisioningState": "Started" } }` + locURL = "https://foo.bar.baz/location" + ) + resp := initialResponse(http.MethodPost, strings.NewReader(jsonBody)) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, locURL) + poller, err := New(resp, "location", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != locURL { + t.Fatalf("unexpected final get URL %s", u) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewFinalGetOrigin(t *testing.T) { + const ( + jsonBody = `{ "properties": { "provisioningState": "Started" } }` + locURL = "https://foo.bar.baz/location" + ) + resp := initialResponse(http.MethodPost, strings.NewReader(jsonBody)) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, locURL) + poller, err := New(resp, "original-uri", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewPutNoProvStateOnUpdate(t *testing.T) { + // missing provisioning state on initial response + // NOTE: ARM RPC forbids this but we allow it for back-compat + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(strings.NewReader("{}"))); err == nil { + t.Fatal("unexpected nil error") + } +} diff --git a/sdk/azcore/arm/internal/pollers/body/body.go b/sdk/azcore/arm/internal/pollers/body/body.go new file mode 100644 index 000000000000..ea5fa6b468e0 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/body/body.go @@ -0,0 +1,111 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "errors" + "net/http" + + armpollers "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "Body" + +// Applicable returns true if the LRO is using no headers, just provisioning state. +// This is only applicable to PATCH and PUT methods and assumes no polling headers. +func Applicable(resp *http.Response) bool { + // we can't check for absense of headers due to some misbehaving services + // like redis that return a Location header but don't actually use that protocol + return resp.Request.Method == http.MethodPatch || resp.Request.Method == http.MethodPut +} + +// Poller is an LRO poller that uses the Body pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *http.Response, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Body poller.") + p := &Poller{ + Type: pollers.MakeID(pollerID, Kind), + PollURL: resp.Request.URL.String(), + } + // default initial state to InProgress. depending on the HTTP + // status code and provisioning state, we might change the value. + curState := pollers.StatusInProgress + provState, err := armpollers.GetProvisioningState(resp) + if err != nil && !errors.Is(err, shared.ErrNoBody) { + return nil, err + } + if resp.StatusCode == http.StatusCreated && provState != "" { + // absense of provisioning state is ok for a 201, means the operation is in progress + curState = provState + } else if resp.StatusCode == http.StatusOK { + if provState != "" { + curState = provState + } else if provState == "" { + // for a 200, absense of provisioning state indicates success + curState = pollers.StatusSucceeded + } + } else if resp.StatusCode == http.StatusNoContent { + curState = pollers.StatusSucceeded + } + p.CurState = curState + return p, nil +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.PollURL +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *http.Response) error { + if resp.StatusCode == http.StatusNoContent { + p.CurState = pollers.StatusSucceeded + return nil + } + state, err := armpollers.GetProvisioningState(resp) + if errors.Is(err, shared.ErrNoBody) { + // a missing response body in non-204 case is an error + return err + } else if state == "" { + // a response body without provisioning state is considered terminal success + state = pollers.StatusSucceeded + } else if err != nil { + return err + } + p.CurState = state + return nil +} + +// FinalGetURL returns the empty string as no final GET is required for this poller type. +func (*Poller) FinalGetURL() string { + return "" +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/azcore/arm/internal/pollers/body/body_test.go b/sdk/azcore/arm/internal/pollers/body/body_test.go new file mode 100644 index 000000000000..aa19670dbd4e --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/body/body_test.go @@ -0,0 +1,207 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "errors" + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, resp io.Reader) *http.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + Request: req, + } +} + +func pollingResponse(status int, resp io.Reader) *http.Response { + return &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + StatusCode: status, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + Request: &http.Request{ + Method: http.MethodDelete, + }, + } + if Applicable(resp) { + t.Fatal("method DELETE should not be applicable") + } + resp.Request.Method = http.MethodPatch + if !Applicable(resp) { + t.Fatal("method PATCH should be applicable") + } + resp.Request.Method = http.MethodPut + if !Applicable(resp) { + t.Fatal("method PUT should be applicable") + } +} + +func TestNew(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusCreated + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{ "properties": { "provisioningState": "InProgress" } }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateNoProvStateFail(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusOK, http.NoBody)) + if err == nil { + t.Fatal("unexpected nil error") + } + if !errors.Is(err, shared.ErrNoBody) { + t.Fatalf("unexpected error type %T", err) + } +} + +func TestUpdateNoProvStateSuccess(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{}`))) + if err != nil { + t.Fatal(err) + } +} + +func TestUpdateNoProvState204(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusNoContent, http.NoBody)) + if err != nil { + t.Fatal(err) + } +} + +func TestNewNoInitialProvStateOK(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("poller not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewNoInitialProvStateNC(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusNoContent + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("poller not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/azcore/arm/internal/pollers/loc/loc.go b/sdk/azcore/arm/internal/pollers/loc/loc.go new file mode 100644 index 000000000000..a1b8a23234fa --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/loc/loc.go @@ -0,0 +1,104 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "errors" + "fmt" + "net/http" + + armpollers "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "ARM-Location" + +// Applicable returns true if the LRO is using Location. +func Applicable(resp *http.Response) bool { + return resp.StatusCode == http.StatusAccepted && resp.Header.Get(shared.HeaderLocation) != "" +} + +// Poller is an LRO poller that uses the Location pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *http.Response, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Location poller.") + locURL := resp.Header.Get(shared.HeaderLocation) + if locURL == "" { + return nil, errors.New("response is missing Location header") + } + if !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid polling URL %s", locURL) + } + p := &Poller{ + Type: pollers.MakeID(pollerID, Kind), + PollURL: locURL, + CurState: pollers.StatusInProgress, + } + return p, nil +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.PollURL +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *http.Response) error { + // location polling can return an updated polling URL + if h := resp.Header.Get(shared.HeaderLocation); h != "" { + p.PollURL = h + } + if runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) { + // if a 200/201 returns a provisioning state, use that instead + state, err := armpollers.GetProvisioningState(resp) + if err != nil && !errors.Is(err, shared.ErrNoBody) { + return err + } + if state != "" { + p.CurState = state + } else { + // a 200/201 with no provisioning state indicates success + p.CurState = pollers.StatusSucceeded + } + } else if resp.StatusCode == http.StatusNoContent { + p.CurState = pollers.StatusSucceeded + } else if resp.StatusCode > 399 && resp.StatusCode < 500 { + p.CurState = pollers.StatusFailed + } + // a 202 falls through, means the LRO is still in progress and we don't check for provisioning state + return nil +} + +// FinalGetURL returns the empty string as no final GET is required for this poller type. +func (p *Poller) FinalGetURL() string { + return "" +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/azcore/arm/internal/pollers/loc/loc_test.go b/sdk/azcore/arm/internal/pollers/loc/loc_test.go new file mode 100644 index 000000000000..06365a1d6ec6 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/loc/loc_test.go @@ -0,0 +1,133 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakePollingURL1 = "https://foo.bar.baz/status" + fakePollingURL2 = "https://foo.bar.baz/updated" +) + +func initialResponse(method string) *http.Response { + return &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + } +} + +func pollingResponse(statusCode int, body io.Reader) *http.Response { + return &http.Response{ + Body: ioutil.NopCloser(body), + Header: http.Header{}, + StatusCode: statusCode, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + } + if Applicable(resp) { + t.Fatal("missing Location should not be applicable") + } + resp.Header.Set(shared.HeaderLocation, fakePollingURL1) + if !Applicable(resp) { + t.Fatal("having Location should be applicable") + } +} + +func TestNew(t *testing.T) { + resp := initialResponse(http.MethodPut) + resp.Header.Set(shared.HeaderLocation, fakePollingURL1) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL1 { + t.Fatalf("unexpected polling URL %s", u) + } + pr := pollingResponse(http.StatusAccepted, http.NoBody) + pr.Header.Set(shared.HeaderLocation, fakePollingURL2) + if err := poller.Update(pr); err != nil { + t.Fatal(err) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusNoContent, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(http.StatusConflict, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Failed" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateWithProvState(t *testing.T) { + resp := initialResponse(http.MethodPut) + resp.Header.Set(shared.HeaderLocation, fakePollingURL1) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL1 { + t.Fatalf("unexpected polling URL %s", u) + } + pr := pollingResponse(http.StatusAccepted, http.NoBody) + pr.Header.Set(shared.HeaderLocation, fakePollingURL2) + if err := poller.Update(pr); err != nil { + t.Fatal(err) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{ "properties": { "provisioningState": "Updating" } }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Updating" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(http.StatusOK, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/azcore/arm/internal/pollers/pollers.go b/sdk/azcore/arm/internal/pollers/pollers.go new file mode 100644 index 000000000000..3b9f5581c017 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/pollers.go @@ -0,0 +1,68 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// provisioningState returns the provisioning state from the response or the empty string. +func provisioningState(jsonBody map[string]interface{}) string { + jsonProps, ok := jsonBody["properties"] + if !ok { + return "" + } + props, ok := jsonProps.(map[string]interface{}) + if !ok { + return "" + } + rawPs, ok := props["provisioningState"] + if !ok { + return "" + } + ps, ok := rawPs.(string) + if !ok { + return "" + } + return ps +} + +// status returns the status from the response or the empty string. +func status(jsonBody map[string]interface{}) string { + rawStatus, ok := jsonBody["status"] + if !ok { + return "" + } + status, ok := rawStatus.(string) + if !ok { + return "" + } + return status +} + +// GetStatus returns the LRO's status from the response body. +// Typically used for Azure-AsyncOperation flows. +// If there is no status in the response body the empty string is returned. +func GetStatus(resp *http.Response) (string, error) { + jsonBody, err := shared.GetJSON(resp) + if err != nil { + return "", err + } + return status(jsonBody), nil +} + +// GetProvisioningState returns the LRO's state from the response body. +// If there is no state in the response body the empty string is returned. +func GetProvisioningState(resp *http.Response) (string, error) { + jsonBody, err := shared.GetJSON(resp) + if err != nil { + return "", err + } + return provisioningState(jsonBody), nil +} diff --git a/sdk/azcore/arm/internal/pollers/pollers_test.go b/sdk/azcore/arm/internal/pollers/pollers_test.go new file mode 100644 index 000000000000..10808c908256 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/pollers_test.go @@ -0,0 +1,91 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +func TestGetStatusSuccess(t *testing.T) { + const jsonBody = `{ "status": "InProgress" }` + resp := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader(jsonBody)), + } + status, err := GetStatus(resp) + if err != nil { + t.Fatal(err) + } + if status != "InProgress" { + t.Fatalf("unexpected status %s", status) + } +} + +func TestGetNoBody(t *testing.T) { + resp := &http.Response{ + Body: http.NoBody, + } + status, err := GetStatus(resp) + if !errors.Is(err, shared.ErrNoBody) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatal("expected empty status") + } + status, err = GetProvisioningState(resp) + if !errors.Is(err, shared.ErrNoBody) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatal("expected empty status") + } +} + +func TestGetStatusError(t *testing.T) { + resp := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("{}")), + } + status, err := GetStatus(resp) + if err != nil { + t.Fatal(err) + } + if status != "" { + t.Fatalf("expected empty status, got %s", status) + } +} + +func TestGetProvisioningState(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Canceled" } }` + resp := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader(jsonBody)), + } + state, err := GetProvisioningState(resp) + if err != nil { + t.Fatal(err) + } + if state != "Canceled" { + t.Fatalf("unexpected status %s", state) + } +} + +func TestGetProvisioningStateError(t *testing.T) { + resp := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("{}")), + } + state, err := GetProvisioningState(resp) + if err != nil { + t.Fatal(err) + } + if state != "" { + t.Fatalf("expected empty provisioning state, got %s", state) + } +} diff --git a/sdk/azcore/arm/runtime/policy_register_rp.go b/sdk/azcore/arm/runtime/policy_register_rp.go new file mode 100644 index 000000000000..9b5e16250277 --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_register_rp.go @@ -0,0 +1,384 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +const ( + // LogRPRegistration entries contain information specific to the automatic registration of an RP. + // Entries of this classification are written IFF the policy needs to take any action. + LogRPRegistration log.Classification = "RPRegistration" +) + +// RegistrationOptions configures the registration policy's behavior. +// All zero-value fields will be initialized with their default values. +type RegistrationOptions struct { + // MaxAttempts is the total number of times to attempt automatic registration + // in the event that an attempt fails. + // The default value is 3. + // Set to a value less than zero to disable the policy. + MaxAttempts int + + // PollingDelay is the amount of time to sleep between polling intervals. + // The default value is 15 seconds. + // A value less than zero means no delay between polling intervals (not recommended). + PollingDelay time.Duration + + // PollingDuration is the amount of time to wait before abandoning polling. + // The default valule is 5 minutes. + // NOTE: Setting this to a small value might cause the policy to prematurely fail. + PollingDuration time.Duration + + // HTTPClient sets the transport for making HTTP requests. + HTTPClient policy.Transporter + + // Retry configures the built-in retry policy behavior. + Retry policy.RetryOptions + + // Telemetry configures the built-in telemetry policy behavior. + Telemetry policy.TelemetryOptions + + // Logging configures the built-in logging policy behavior. + Logging policy.LogOptions +} + +// init sets any default values +func (r *RegistrationOptions) init() { + if r.MaxAttempts == 0 { + r.MaxAttempts = 3 + } else if r.MaxAttempts < 0 { + r.MaxAttempts = 0 + } + if r.PollingDelay == 0 { + r.PollingDelay = 15 * time.Second + } else if r.PollingDelay < 0 { + r.PollingDelay = 0 + } + if r.PollingDuration == 0 { + r.PollingDuration = 5 * time.Minute + } +} + +// NewRPRegistrationPolicy creates a policy object configured using the specified endpoint, +// credentials and options. The policy controls if an unregistered resource provider should +// automatically be registered. See https://aka.ms/rps-not-found for more information. +// Pass nil to accept the default options; this is the same as passing a zero-value options. +func NewRPRegistrationPolicy(endpoint string, cred azcore.Credential, o *RegistrationOptions) policy.Policy { + if o == nil { + o = &RegistrationOptions{} + } + p := &rpRegistrationPolicy{ + endpoint: endpoint, + pipeline: runtime.NewPipeline(o.HTTPClient, + runtime.NewTelemetryPolicy(shared.Module, shared.Version, &o.Telemetry), + runtime.NewRetryPolicy(&o.Retry), + cred.NewAuthenticationPolicy(runtime.AuthenticationOptions{TokenRequest: policy.TokenRequestOptions{Scopes: []string{shared.EndpointToScope(endpoint)}}}), + runtime.NewLogPolicy(&o.Logging)), + options: *o, + } + // init the copy + p.options.init() + return p +} + +type rpRegistrationPolicy struct { + endpoint string + pipeline pipeline.Pipeline + options RegistrationOptions +} + +func (r *rpRegistrationPolicy) Do(req *policy.Request) (*http.Response, error) { + if r.options.MaxAttempts == 0 { + // policy is disabled + return req.Next() + } + const unregisteredRPCode = "MissingSubscriptionRegistration" + const registeredState = "Registered" + var rp string + var resp *http.Response + for attempts := 0; attempts < r.options.MaxAttempts; attempts++ { + var err error + // make the original request + resp, err = req.Next() + // getting a 409 is the first indication that the RP might need to be registered, check error response + if err != nil || resp.StatusCode != http.StatusConflict { + return resp, err + } + var reqErr requestError + if err = runtime.UnmarshalAsJSON(resp, &reqErr); err != nil { + return resp, err + } + if reqErr.ServiceError == nil { + return resp, errors.New("missing error information") + } + if !strings.EqualFold(reqErr.ServiceError.Code, unregisteredRPCode) { + // not a 409 due to unregistered RP + return resp, err + } + // RP needs to be registered. start by getting the subscription ID from the original request + subID, err := getSubscription(req.Raw().URL.Path) + if err != nil { + return resp, err + } + // now get the RP from the error + rp, err = getProvider(reqErr) + if err != nil { + return resp, err + } + logRegistrationExit := func(v interface{}) { + log.Writef(LogRPRegistration, "END registration for %s: %v", rp, v) + } + log.Writef(LogRPRegistration, "BEGIN registration for %s", rp) + // create client and make the registration request + // we use the scheme and host from the original request + rpOps := &providersOperations{ + p: r.pipeline, + u: r.endpoint, + subID: subID, + } + if _, err = rpOps.Register(req.Raw().Context(), rp); err != nil { + logRegistrationExit(err) + return resp, err + } + // RP was registered, however we need to wait for the registration to complete + pollCtx, pollCancel := context.WithTimeout(req.Raw().Context(), r.options.PollingDuration) + var lastRegState string + for { + // get the current registration state + getResp, err := rpOps.Get(pollCtx, rp) + if err != nil { + pollCancel() + logRegistrationExit(err) + return resp, err + } + if getResp.Provider.RegistrationState != nil && !strings.EqualFold(*getResp.Provider.RegistrationState, lastRegState) { + // registration state has changed, or was updated for the first time + lastRegState = *getResp.Provider.RegistrationState + log.Writef(LogRPRegistration, "registration state is %s", lastRegState) + } + if strings.EqualFold(lastRegState, registeredState) { + // registration complete + pollCancel() + logRegistrationExit(lastRegState) + break + } + // wait before trying again + select { + case <-time.After(r.options.PollingDelay): + // continue polling + case <-pollCtx.Done(): + pollCancel() + logRegistrationExit(pollCtx.Err()) + return resp, pollCtx.Err() + } + } + // RP was successfully registered, retry the original request + err = req.RewindBody() + if err != nil { + return resp, err + } + } + // if we get here it means we exceeded the number of attempts + return resp, fmt.Errorf("exceeded attempts to register %s", rp) +} + +func getSubscription(path string) (string, error) { + parts := strings.Split(path, "/") + for i, v := range parts { + if v == "subscriptions" && (i+1) < len(parts) { + return parts[i+1], nil + } + } + return "", fmt.Errorf("failed to obtain subscription ID from %s", path) +} + +func getProvider(re requestError) (string, error) { + if len(re.ServiceError.Details) > 0 { + return re.ServiceError.Details[0].Target, nil + } + return "", errors.New("unexpected empty Details") +} + +// minimal error definitions to simplify detection +type requestError struct { + ServiceError *serviceError `json:"error"` +} + +type serviceError struct { + Code string `json:"code"` + Details []serviceErrorDetails `json:"details"` +} + +type serviceErrorDetails struct { + Code string `json:"code"` + Target string `json:"target"` +} + +/////////////////////////////////////////////////////////////////////////////////////////////// +// the following code was copied from module armresources, providers.go and models.go +// only the minimum amount of code was copied to get this working and some edits were made. +/////////////////////////////////////////////////////////////////////////////////////////////// + +type providersOperations struct { + p pipeline.Pipeline + u string + subID string +} + +// Get - Gets the specified resource provider. +func (client *providersOperations) Get(ctx context.Context, resourceProviderNamespace string) (*ProviderResponse, error) { + req, err := client.getCreateRequest(ctx, resourceProviderNamespace) + if err != nil { + return nil, err + } + resp, err := client.p.Do(req) + if err != nil { + return nil, err + } + result, err := client.getHandleResponse(resp) + if err != nil { + return nil, err + } + return result, nil +} + +// getCreateRequest creates the Get request. +func (client *providersOperations) getCreateRequest(ctx context.Context, resourceProviderNamespace string) (*policy.Request, error) { + urlPath := "/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}" + urlPath = strings.ReplaceAll(urlPath, "{resourceProviderNamespace}", url.PathEscape(resourceProviderNamespace)) + urlPath = strings.ReplaceAll(urlPath, "{subscriptionId}", url.PathEscape(client.subID)) + req, err := runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(client.u, urlPath)) + if err != nil { + return nil, err + } + query := req.Raw().URL.Query() + query.Set("api-version", "2019-05-01") + req.Raw().URL.RawQuery = query.Encode() + return req, nil +} + +// getHandleResponse handles the Get response. +func (client *providersOperations) getHandleResponse(resp *http.Response) (*ProviderResponse, error) { + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, client.getHandleError(resp) + } + result := ProviderResponse{RawResponse: resp} + err := runtime.UnmarshalAsJSON(resp, &result.Provider) + if err != nil { + return nil, err + } + return &result, err +} + +// getHandleError handles the Get error response. +func (client *providersOperations) getHandleError(resp *http.Response) error { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return shared.NewResponseError(err, resp) + } + if len(body) == 0 { + return shared.NewResponseError(errors.New(resp.Status), resp) + } + return shared.NewResponseError(errors.New(string(body)), resp) +} + +// Register - Registers a subscription with a resource provider. +func (client *providersOperations) Register(ctx context.Context, resourceProviderNamespace string) (*ProviderResponse, error) { + req, err := client.registerCreateRequest(ctx, resourceProviderNamespace) + if err != nil { + return nil, err + } + resp, err := client.p.Do(req) + if err != nil { + return nil, err + } + result, err := client.registerHandleResponse(resp) + if err != nil { + return nil, err + } + return result, nil +} + +// registerCreateRequest creates the Register request. +func (client *providersOperations) registerCreateRequest(ctx context.Context, resourceProviderNamespace string) (*policy.Request, error) { + urlPath := "/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}/register" + urlPath = strings.ReplaceAll(urlPath, "{resourceProviderNamespace}", url.PathEscape(resourceProviderNamespace)) + urlPath = strings.ReplaceAll(urlPath, "{subscriptionId}", url.PathEscape(client.subID)) + req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(client.u, urlPath)) + if err != nil { + return nil, err + } + query := req.Raw().URL.Query() + query.Set("api-version", "2019-05-01") + req.Raw().URL.RawQuery = query.Encode() + return req, nil +} + +// registerHandleResponse handles the Register response. +func (client *providersOperations) registerHandleResponse(resp *http.Response) (*ProviderResponse, error) { + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, client.registerHandleError(resp) + } + result := ProviderResponse{RawResponse: resp} + err := runtime.UnmarshalAsJSON(resp, &result.Provider) + if err != nil { + return nil, err + } + return &result, err +} + +// registerHandleError handles the Register error response. +func (client *providersOperations) registerHandleError(resp *http.Response) error { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return shared.NewResponseError(err, resp) + } + if len(body) == 0 { + return shared.NewResponseError(errors.New(resp.Status), resp) + } + return shared.NewResponseError(errors.New(string(body)), resp) +} + +// ProviderResponse is the response envelope for operations that return a Provider type. +type ProviderResponse struct { + // Resource provider information. + Provider *Provider + + // RawResponse contains the underlying HTTP response. + RawResponse *http.Response +} + +// Provider - Resource provider information. +type Provider struct { + // The provider ID. + ID *string `json:"id,omitempty"` + + // The namespace of the resource provider. + Namespace *string `json:"namespace,omitempty"` + + // The registration policy of the resource provider. + RegistrationPolicy *string `json:"registrationPolicy,omitempty"` + + // The registration state of the resource provider. + RegistrationState *string `json:"registrationState,omitempty"` +} diff --git a/sdk/azcore/arm/runtime/policy_register_rp_test.go b/sdk/azcore/arm/runtime/policy_register_rp_test.go new file mode 100644 index 000000000000..05b5318f5567 --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_register_rp_test.go @@ -0,0 +1,372 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "errors" + "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const rpUnregisteredResp = `{ + "error":{ + "code":"MissingSubscriptionRegistration", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions.", + "details":[{ + "code":"MissingSubscriptionRegistration", + "target":"Microsoft.Storage", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions." + } + ] + } +}` + +// some content was omitted here as it's not relevant +const rpRegisteringResp = `{ + "id": "/subscriptions/00000000-0000-0000-0000-000000000000/providers/Microsoft.Storage", + "namespace": "Microsoft.Storage", + "registrationState": "Registering", + "registrationPolicy": "RegistrationRequired" +}` + +// some content was omitted here as it's not relevant +const rpRegisteredResp = `{ + "id": "/subscriptions/00000000-0000-0000-0000-000000000000/providers/Microsoft.Storage", + "namespace": "Microsoft.Storage", + "registrationState": "Registered", + "registrationPolicy": "RegistrationRequired" +}` + +const requestEndpoint = "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/fakeResourceGroupo/providers/Microsoft.Storage/storageAccounts/fakeAccountName" + +func testRPRegistrationOptions(t policy.Transporter) *RegistrationOptions { + def := RegistrationOptions{} + def.HTTPClient = t + def.PollingDelay = 100 * time.Millisecond + def.PollingDuration = 1 * time.Second + return &def +} + +type mockTokenCred struct{} + +func (mockTokenCred) NewAuthenticationPolicy(runtime.AuthenticationOptions) policy.Policy { + return pipeline.PolicyFunc(func(req *policy.Request) (*http.Response, error) { + return req.Next() + }) +} + +func (mockTokenCred) GetToken(context.Context, policy.TokenRequestOptions) (*azcore.AccessToken, error) { + return &azcore.AccessToken{ + Token: "abc123", + ExpiresOn: time.Now().Add(1 * time.Hour), + }, nil +} + +func TestRPRegistrationPolicySuccess(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + // polling responses to Register() and Get(), in progress + srv.RepeatResponse(5, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp))) + // polling response, successful registration + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteredResp))) + // response for original request (different status code than any of the other responses) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), mockTokenCred{}, testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("unexpected status code %d:", resp.StatusCode) + } + if resp.Request.URL.Path != requestEndpoint { + t.Fatalf("unexpected path in response %s", resp.Request.URL.Path) + } + // should be four entries + // 1st is for start + // 2nd is for first response to get state + // 3rd is when state transitions to success + // 4th is for end + if logEntries != 4 { + t.Fatalf("expected 4 log entries, got %d", logEntries) + } +} + +func TestRPRegistrationPolicyNA(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // response indicates no RP registration is required, policy does nothing + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + log.SetListener(func(cls log.Classification, msg string) { + t.Fatalf("unexpected log entry %s: %s", cls, msg) + }) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestRPRegistrationPolicy409Other(t *testing.T) { + const failedResp = `{ + "error":{ + "code":"CannotDoTheThing", + "message":"Something failed in your API call.", + "details":[{ + "code":"ThisIsForTesting", + "message":"This is fake." + } + ] + } + }` + srv, close := mock.NewServer() + defer close() + // test getting a 409 but not due to registration required + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(failedResp))) + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + log.SetListener(func(cls log.Classification, msg string) { + t.Fatalf("unexpected log entry %s: %s", cls, msg) + }) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestRPRegistrationPolicyTimesOut(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + // polling responses to Register() and Get(), in progress but slow + // tests registration takes too long, times out + srv.RepeatResponse(10, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp)), mock.WithSlowResponse(400*time.Millisecond)) + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := pl.Do(req) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected DeadlineExceeded, got %v", err) + } + // should be three entries + // 1st is for start + // 2nd is for first response to get state + // 3rd is the deadline exceeded error + if logEntries != 3 { + t.Fatalf("expected 3 log entries, got %d", logEntries) + } + // we should get the response from the original request + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestRPRegistrationPolicyExceedsAttempts(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // add a cycle of unregistered->registered so that we keep retrying and hit the cap + for i := 0; i < 4; i++ { + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + // polling responses to Register() and Get(), in progress + srv.RepeatResponse(2, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp))) + // polling response, successful registration + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteredResp))) + } + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if !strings.HasPrefix(err.Error(), "exceeded attempts to register Microsoft.Storage") { + t.Fatalf("unexpected error message %s", err.Error()) + } + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d:", resp.StatusCode) + } + if resp.Request.URL.Path != requestEndpoint { + t.Fatalf("unexpected path in response %s", resp.Request.URL.Path) + } + // should be 4 entries for each attempt, total 12 entries + // 1st is for start + // 2nd is for first response to get state + // 3rd is when state transitions to success + // 4th is for end + if logEntries != 12 { + t.Fatalf("expected 12 log entries, got %d", logEntries) + } +} + +// test cancelling registration +func TestRPRegistrationPolicyCanCancel(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + // polling responses to Register() and Get(), in progress but slow so we have time to cancel + srv.RepeatResponse(10, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp)), mock.WithSlowResponse(300*time.Millisecond)) + opts := RegistrationOptions{} + opts.HTTPClient = srv + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), &opts)) + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + + wg := &sync.WaitGroup{} + wg.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + var resp *http.Response + var err error + go func() { + defer wg.Done() + // create request and start pipeline + var req *policy.Request + req, err = runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + return + } + resp, err = pl.Do(req) + }() + + // wait for a bit then cancel the operation + time.Sleep(500 * time.Millisecond) + cancel() + wg.Wait() + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected Canceled error, got %v", err) + } + // there should be 1 or 2 entries depending on the timing + if logEntries == 0 { + t.Fatal("didn't get any log entries") + } + // should have original response + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestRPRegistrationPolicyDisabled(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + ops := testRPRegistrationOptions(srv) + ops.MaxAttempts = -1 + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), ops)) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d:", resp.StatusCode) + } + // shouldn't be any log entries + if logEntries != 0 { + t.Fatalf("expected 0 log entries, got %d", logEntries) + } +} diff --git a/sdk/azcore/arm/runtime/poller.go b/sdk/azcore/arm/runtime/poller.go new file mode 100644 index 000000000000..f4e6df175bb0 --- /dev/null +++ b/sdk/azcore/arm/runtime/poller.go @@ -0,0 +1,81 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/loc" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// NewPoller creates a Poller based on the provided initial response. +// pollerID - a unique identifier for an LRO. it's usually the client.Method string. +func NewPoller(pollerID string, finalState string, resp *http.Response, pl pipeline.Pipeline, eu func(*http.Response) error) (*pollers.Poller, error) { + // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). + // ideally the codegen should return an error if the initial response failed and not even create a poller. + if !pollers.StatusCodeValid(resp) { + return nil, errors.New("the LRO failed or was cancelled") + } + // determine the polling method + var lro pollers.Operation + var err error + if async.Applicable(resp) { + lro, err = async.New(resp, finalState, pollerID) + } else if loc.Applicable(resp) { + lro, err = loc.New(resp, pollerID) + } else if body.Applicable(resp) { + // must test body poller last as it's a subset of the other pollers. + // TODO: this is ambiguous for PATCH/PUT if it returns a 200 with no polling headers (sync completion) + lro, err = body.New(resp, pollerID) + } else if m := resp.Request.Method; resp.StatusCode == http.StatusAccepted && (m == http.MethodDelete || m == http.MethodPost) { + // if we get here it means we have a 202 with no polling headers. + // for DELETE and POST this is a hard error per ARM RPC spec. + return nil, errors.New("response is missing polling URL") + } else { + lro = &pollers.NopPoller{} + } + if err != nil { + return nil, err + } + return pollers.NewPoller(lro, resp, pl, eu), nil +} + +// NewPollerFromResumeToken creates a Poller from a resume token string. +// pollerID - a unique identifier for an LRO. it's usually the client.Method string. +func NewPollerFromResumeToken(pollerID string, token string, pl pipeline.Pipeline, eu func(*http.Response) error) (*pollers.Poller, error) { + kind, err := pollers.KindFromToken(pollerID, token) + if err != nil { + return nil, err + } + // now rehydrate the poller based on the encoded poller type + var lro pollers.Operation + switch kind { + case async.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", async.Kind) + lro = &async.Poller{} + case loc.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", loc.Kind) + lro = &loc.Poller{} + case body.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", body.Kind) + lro = &body.Poller{} + default: + return nil, fmt.Errorf("unhandled poller type %s", kind) + } + if err = json.Unmarshal([]byte(token), lro); err != nil { + return nil, err + } + return pollers.NewPoller(lro, nil, pl, eu), nil +} diff --git a/sdk/azcore/arm/runtime/poller_test.go b/sdk/azcore/arm/runtime/poller_test.go new file mode 100644 index 000000000000..6ac0429eeb1d --- /dev/null +++ b/sdk/azcore/arm/runtime/poller_test.go @@ -0,0 +1,308 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "reflect" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/loc" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + provStateStarted = `{ "properties": { "provisioningState": "Started" } }` + provStateUpdating = `{ "properties": { "provisioningState": "Updating" } }` + provStateSucceeded = `{ "properties": { "provisioningState": "Succeeded" }, "field": "value" }` + provStateFailed = `{ "properties": { "provisioningState": "Failed" } }` //nolint + statusInProgress = `{ "status": "InProgress" }` + statusSucceeded = `{ "status": "Succeeded" }` + statusCanceled = `{ "status": "Canceled" }` + successResp = `{ "field": "value" }` + errorResp = `{ "error": "the operation failed" }` +) + +type mockType struct { + Field *string `json:"field,omitempty"` +} + +type mockError struct { + Msg string `json:"error"` +} + +func (m mockError) Error() string { + return m.Msg +} + +func getPipeline(srv *mock.Server) pipeline.Pipeline { + return runtime.NewPipeline( + srv, + runtime.NewLogPolicy(nil)) +} + +func handleError(resp *http.Response) error { + var me mockError + if err := runtime.UnmarshalAsJSON(resp, &me); err != nil { + return err + } + return me +} + +func initialResponse(method, u string, resp io.Reader) *http.Response { + req, err := http.NewRequest(method, u, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: ioutil.NopCloser(resp), + ContentLength: -1, + Header: http.Header{}, + Request: req, + } +} + +func TestNewPollerAsync(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&async.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewPollerFromResumeToken("pollerID", tk, pl, handleError) + if err != nil { + t.Fatal(err) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestNewPollerBody(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(provStateUpdating)), mock.WithHeader("Retry-After", "1")) + srv.AppendResponse(mock.WithBody([]byte(provStateSucceeded))) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&body.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewPollerFromResumeToken("pollerID", tk, pl, handleError) + if err != nil { + t.Fatal(err) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestNewPollerLoc(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderLocation, srv.URL()) + resp.StatusCode = http.StatusAccepted + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&loc.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewPollerFromResumeToken("pollerID", tk, pl, handleError) + if err != nil { + t.Fatal(err) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestNewPollerInitialRetryAfter(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) + resp.Header.Set("Retry-After", "1") + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&async.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestNewPollerCanceled(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusCanceled)), mock.WithStatusCode(http.StatusOK)) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&async.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + _, err = poller.Poll(context.Background()) + if err != nil { + t.Fatal(err) + } + _, err = poller.Poll(context.Background()) + if err == nil { + t.Fatal("unexpected nil error") + } +} + +func TestNewPollerFailedWithError(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(errorResp)), mock.WithStatusCode(http.StatusBadRequest)) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&async.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err == nil { + t.Fatal(err) + } + if _, ok := err.(mockError); !ok { + t.Fatalf("unexpected error type %T", err) + } +} + +func TestNewPollerSuccessNoContent(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(provStateUpdating))) + srv.AppendResponse(mock.WithStatusCode(http.StatusNoContent)) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&body.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewPollerFromResumeToken("pollerID", tk, pl, handleError) + if err != nil { + t.Fatal(err) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if result.Field != nil { + t.Fatal("expected nil result") + } +} + +func TestNewPollerFail202NoHeaders(t *testing.T) { + srv, close := mock.NewServer() + defer close() + resp := initialResponse(http.MethodDelete, srv.URL(), http.NoBody) + resp.StatusCode = http.StatusAccepted + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index cfd2ede845ef..660b0cbe07ca 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -7,111 +7,11 @@ package azcore import ( - "errors" - "io" - "net/http" "reflect" -) -const ( - headerContentLength = "Content-Length" - headerContentType = "Content-Type" - headerOperationLocation = "Operation-Location" - headerLocation = "Location" - headerRetryAfter = "Retry-After" - headerUserAgent = "User-Agent" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" ) -// Policy represents an extensibility point for the Pipeline that can mutate the specified -// Request and react to the received Response. -type Policy interface { - // Do applies the policy to the specified Request. When implementing a Policy, mutate the - // request before calling req.Next() to move on to the next policy, and respond to the result - // before returning to the caller. - Do(req *Request) (*http.Response, error) -} - -// policyFunc is a type that implements the Policy interface. -// Use this type when implementing a stateless policy as a first-class function. -type policyFunc func(*Request) (*http.Response, error) - -// Do implements the Policy interface on PolicyFunc. -func (pf policyFunc) Do(req *Request) (*http.Response, error) { - return pf(req) -} - -// Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. -type Transporter interface { - // Do sends the HTTP request and returns the HTTP response or error. - Do(req *http.Request) (*http.Response, error) -} - -// used to adapt a TransportPolicy to a Policy -type transportPolicy struct { - trans Transporter -} - -func (tp transportPolicy) Do(req *Request) (*http.Response, error) { - resp, err := tp.trans.Do(req.Request) - if err != nil { - return nil, err - } else if resp == nil { - // there was no response and no error (rare but can happen) - // this ensures the retry policy will retry the request - return nil, errors.New("received nil response") - } - return resp, nil -} - -// Pipeline represents a primitive for sending HTTP requests and receiving responses. -// Its behavior can be extended by specifying policies during construction. -type Pipeline struct { - policies []Policy -} - -// NewPipeline creates a new Pipeline object from the specified Transport and Policies. -// If no transport is provided then the default *http.Client transport will be used. -func NewPipeline(transport Transporter, policies ...Policy) Pipeline { - if transport == nil { - transport = defaultHTTPClient - } - // transport policy must always be the last in the slice - policies = append(policies, policyFunc(httpHeaderPolicy), policyFunc(bodyDownloadPolicy), transportPolicy{trans: transport}) - return Pipeline{ - policies: policies, - } -} - -// Do is called for each and every HTTP request. It passes the request through all -// the Policy objects (which can transform the Request's URL/query parameters/headers) -// and ultimately sends the transformed HTTP request over the network. -func (p Pipeline) Do(req *Request) (*http.Response, error) { - if err := req.valid(); err != nil { - return nil, err - } - req.policies = p.policies - return req.Next() -} - -// ReadSeekCloser is the interface that groups the io.ReadCloser and io.Seeker interfaces. -type ReadSeekCloser interface { - io.ReadCloser - io.Seeker -} - -type nopCloser struct { - io.ReadSeeker -} - -func (n nopCloser) Close() error { - return nil -} - -// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. -func NopCloser(rs io.ReadSeeker) ReadSeekCloser { - return nopCloser{rs} -} - // holds sentinel values used to send nulls var nullables map[reflect.Type]interface{} = map[reflect.Type]interface{}{} @@ -159,3 +59,6 @@ func IsNullValue(v interface{}) bool { // no sentinel object for this *t return false } + +// Poller encapsulates state and logic for polling on long-running operations. +type Poller = pollers.Poller diff --git a/sdk/azcore/credential.go b/sdk/azcore/credential.go index 441f3d0bc37c..d55a356964f9 100644 --- a/sdk/azcore/credential.go +++ b/sdk/azcore/credential.go @@ -9,31 +9,23 @@ package azcore import ( "context" "time" -) -// AuthenticationOptions contains various options used to create a credential policy. -type AuthenticationOptions struct { - // TokenRequest is a TokenRequestOptions that includes a scopes field which contains - // the list of OAuth2 authentication scopes used when requesting a token. - // This field is ignored for other forms of authentication (e.g. shared key). - TokenRequest TokenRequestOptions - // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate - // in cross-tenant applications. - AuxiliaryTenants []string -} + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) // Credential represents any credential type. type Credential interface { // AuthenticationPolicy returns a policy that requests the credential and applies it to the HTTP request. - NewAuthenticationPolicy(options AuthenticationOptions) Policy + NewAuthenticationPolicy(options runtime.AuthenticationOptions) policy.Policy } // credentialFunc is a type that implements the Credential interface. // Use this type when implementing a stateless credential as a first-class function. -type credentialFunc func(options AuthenticationOptions) Policy +type credentialFunc func(options runtime.AuthenticationOptions) policy.Policy // AuthenticationPolicy implements the Credential interface on credentialFunc. -func (cf credentialFunc) NewAuthenticationPolicy(options AuthenticationOptions) Policy { +func (cf credentialFunc) NewAuthenticationPolicy(options runtime.AuthenticationOptions) policy.Policy { return cf(options) } @@ -41,7 +33,7 @@ func (cf credentialFunc) NewAuthenticationPolicy(options AuthenticationOptions) type TokenCredential interface { Credential // GetToken requests an access token for the specified set of scopes. - GetToken(ctx context.Context, options TokenRequestOptions) (*AccessToken, error) + GetToken(ctx context.Context, options policy.TokenRequestOptions) (*AccessToken, error) } // AccessToken represents an Azure service bearer access token with expiry information. @@ -49,12 +41,3 @@ type AccessToken struct { Token string ExpiresOn time.Time } - -// TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. -type TokenRequestOptions struct { - // Scopes contains the list of permission scopes required for the token. - Scopes []string - // TenantID contains the tenant ID to use in a multi-tenant authentication scenario, if TenantID is set - // it will override the tenant ID that was added at credential creation time. - TenantID string -} diff --git a/sdk/azcore/error.go b/sdk/azcore/error.go deleted file mode 100644 index e2547faf808c..000000000000 --- a/sdk/azcore/error.go +++ /dev/null @@ -1,70 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "net/http" -) - -var ( - // StackFrameCount contains the number of stack frames to include when a trace is being collected. - StackFrameCount = 32 -) - -// HTTPResponse provides access to an HTTP response when available. -// Errors returned from failed API calls will implement this interface. -// Use errors.As() to access this interface in the error chain. -// If there was no HTTP response then this interface will be omitted -// from any error in the chain. -type HTTPResponse interface { - RawResponse() *http.Response -} - -// NonRetriableError represents a non-transient error. This works in -// conjunction with the retry policy, indicating that the error condition -// is idempotent, so no retries will be attempted. -// Use errors.As() to access this interface in the error chain. -type NonRetriableError interface { - error - NonRetriable() -} - -// NewResponseError wraps the specified error with an error that provides access to an HTTP response. -// If an HTTP request returns a non-successful status code, wrap the response and the associated error -// in this error type so that callers can access the underlying *http.Response as required. -// DO NOT wrap failed HTTP requests that returned an error and no response with this type. -func NewResponseError(inner error, resp *http.Response) error { - return &responseError{inner: inner, resp: resp} -} - -type responseError struct { - inner error - resp *http.Response -} - -// Error implements the error interface for type ResponseError. -func (e *responseError) Error() string { - return e.inner.Error() -} - -// Unwrap returns the inner error. -func (e *responseError) Unwrap() error { - return e.inner -} - -// RawResponse returns the HTTP response associated with this error. -func (e *responseError) RawResponse() *http.Response { - return e.resp -} - -// NonRetriable indicates this error is non-transient. -func (e *responseError) NonRetriable() { - // marker method -} - -var _ HTTPResponse = (*responseError)(nil) -var _ NonRetriableError = (*responseError)(nil) diff --git a/sdk/azcore/errors.go b/sdk/azcore/errors.go new file mode 100644 index 000000000000..222c8f85f21b --- /dev/null +++ b/sdk/azcore/errors.go @@ -0,0 +1,26 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcore + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" +) + +// HTTPResponse provides access to an HTTP response when available. +// Errors returned from failed API calls will implement this interface. +// Use errors.As() to access this interface in the error chain. +// If there was no HTTP response then this interface will be omitted +// from any error in the chain. +type HTTPResponse interface { + RawResponse() *http.Response +} + +var _ HTTPResponse = (*shared.ResponseError)(nil) +var _ errorinfo.NonRetriable = (*shared.ResponseError)(nil) diff --git a/sdk/azcore/example_test.go b/sdk/azcore/example_test.go index 0506a3f75bac..b13cb08a11fd 100644 --- a/sdk/azcore/example_test.go +++ b/sdk/azcore/example_test.go @@ -8,58 +8,23 @@ package azcore_test import ( - "context" "encoding/json" "fmt" - "io/ioutil" - "log" - "net/http" - "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" ) -func ExamplePipeline_Do() { - req, err := azcore.NewRequest(context.Background(), http.MethodGet, "https://github.com/robots.txt") - if err != nil { - log.Fatal(err) - } - pipeline := azcore.NewPipeline(nil) - resp, err := pipeline.Do(req) - if err != nil { - log.Fatal(err) - } - robots, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - log.Fatal(err) - } - fmt.Printf("%s", robots) -} - -func ExampleRequest_SetBody() { - req, err := azcore.NewRequest(context.Background(), http.MethodPut, "https://contoso.com/some/endpoint") - if err != nil { - log.Fatal(err) - } - body := strings.NewReader("this is seekable content to be uploaded") - err = req.SetBody(azcore.NopCloser(body), "text/plain") - if err != nil { - log.Fatal(err) - } -} - // false positive by linter func ExampleSetClassifications() { //nolint:govet // only log HTTP requests and responses - azlog.SetClassifications(azlog.Request, azlog.Response) + log.SetClassifications(log.Request, log.Response) } // false positive by linter func ExampleSetListener() { //nolint:govet // a simple logger that writes to stdout - azlog.SetListener(func(cls azlog.Classification, msg string) { + log.SetListener(func(cls log.Classification, msg string) { fmt.Printf("%s: %s\n", cls, msg) }) } diff --git a/sdk/azcore/internal/pipeline/pipeline.go b/sdk/azcore/internal/pipeline/pipeline.go new file mode 100644 index 000000000000..e2c9f115a1d7 --- /dev/null +++ b/sdk/azcore/internal/pipeline/pipeline.go @@ -0,0 +1,93 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "errors" + "fmt" + "net/http" + + "golang.org/x/net/http/httpguts" +) + +// Policy represents an extensibility point for the Pipeline that can mutate the specified +// Request and react to the received Response. +type Policy interface { + // Do applies the policy to the specified Request. When implementing a Policy, mutate the + // request before calling req.Next() to move on to the next policy, and respond to the result + // before returning to the caller. + Do(req *Request) (*http.Response, error) +} + +// Pipeline represents a primitive for sending HTTP requests and receiving responses. +// Its behavior can be extended by specifying policies during construction. +type Pipeline struct { + policies []Policy +} + +// Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. +type Transporter interface { + // Do sends the HTTP request and returns the HTTP response or error. + Do(req *http.Request) (*http.Response, error) +} + +// used to adapt a TransportPolicy to a Policy +type transportPolicy struct { + trans Transporter +} + +func (tp transportPolicy) Do(req *Request) (*http.Response, error) { + if tp.trans == nil { + return nil, errors.New("missing transporter") + } + resp, err := tp.trans.Do(req.Raw()) + if err != nil { + return nil, err + } else if resp == nil { + // there was no response and no error (rare but can happen) + // this ensures the retry policy will retry the request + return nil, errors.New("received nil response") + } + return resp, nil +} + +// NewPipeline creates a new Pipeline object from the specified Policies. +func NewPipeline(transport Transporter, policies ...Policy) Pipeline { + // transport policy must always be the last in the slice + policies = append(policies, transportPolicy{trans: transport}) + return Pipeline{ + policies: policies, + } +} + +// Do is called for each and every HTTP request. It passes the request through all +// the Policy objects (which can transform the Request's URL/query parameters/headers) +// and ultimately sends the transformed HTTP request over the network. +func (p Pipeline) Do(req *Request) (*http.Response, error) { + if req == nil { + return nil, errors.New("request cannot be nil") + } + // check copied from Transport.roundTrip() + for k, vv := range req.Raw().Header { + if !httpguts.ValidHeaderFieldName(k) { + if req.Raw().Body != nil { + req.Raw().Body.Close() + } + return nil, fmt.Errorf("invalid header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + if req.Raw().Body != nil { + req.Raw().Body.Close() + } + return nil, fmt.Errorf("invalid header field value %q for key %v", v, k) + } + } + } + req.policies = p.policies + return req.Next() +} diff --git a/sdk/azcore/internal/pipeline/pipeline_test.go b/sdk/azcore/internal/pipeline/pipeline_test.go new file mode 100644 index 000000000000..81bc22e698af --- /dev/null +++ b/sdk/azcore/internal/pipeline/pipeline_test.go @@ -0,0 +1,103 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "context" + "errors" + "net/http" + "testing" +) + +func TestPipelineErrors(t *testing.T) { + pl := NewPipeline(nil) + resp, err := pl.Do(nil) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + req, err := NewRequest(context.Background(), http.MethodGet, testURL) + if err != nil { + t.Fatal(err) + } + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + req.Raw().Header["Invalid"] = []string{string([]byte{0})} + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + req, err = NewRequest(context.Background(), http.MethodGet, testURL) + if err != nil { + t.Fatal(err) + } + req.Raw().Header["Inv alid"] = []string{"value"} + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +type mockTransport struct { + succeed bool + both bool +} + +func (m *mockTransport) Do(*http.Request) (*http.Response, error) { + if m.both { + return nil, nil + } + if m.succeed { + return &http.Response{StatusCode: http.StatusOK}, nil + } + return nil, errors.New("failed") +} + +func TestPipelineDo(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodGet, testURL) + if err != nil { + t.Fatal(err) + } + tp := mockTransport{succeed: true} + pl := NewPipeline(&tp) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if sc := resp.StatusCode; sc != http.StatusOK { + t.Fatalf("unexpected status code %d", sc) + } + tp.succeed = false + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + tp.both = true + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } +} diff --git a/sdk/azcore/internal/pipeline/request.go b/sdk/azcore/internal/pipeline/request.go new file mode 100644 index 000000000000..e261f30429a2 --- /dev/null +++ b/sdk/azcore/internal/pipeline/request.go @@ -0,0 +1,169 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "reflect" + "strconv" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// PolicyFunc is a type that implements the Policy interface. +// Use this type when implementing a stateless policy as a first-class function. +type PolicyFunc func(*Request) (*http.Response, error) + +// Do implements the Policy interface on PolicyFunc. +func (pf PolicyFunc) Do(req *Request) (*http.Response, error) { + return pf(req) +} + +// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. +// Don't use this type directly, use NewRequest() instead. +type Request struct { + req *http.Request + body io.ReadSeekCloser + policies []Policy + values opValues +} + +type opValues map[reflect.Type]interface{} + +// Set adds/changes a value +func (ov opValues) set(value interface{}) { + ov[reflect.TypeOf(value)] = value +} + +// Get looks for a value set by SetValue first +func (ov opValues) get(value interface{}) bool { + v, ok := ov[reflect.ValueOf(value).Elem().Type()] + if ok { + reflect.ValueOf(value).Elem().Set(reflect.ValueOf(v)) + } + return ok +} + +// NewRequest creates a new Request with the specified input. +func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) { + req, err := http.NewRequestWithContext(ctx, httpMethod, endpoint, nil) + if err != nil { + return nil, err + } + if req.URL.Host == "" { + return nil, errors.New("no Host in request URL") + } + if !(req.URL.Scheme == "http" || req.URL.Scheme == "https") { + return nil, fmt.Errorf("unsupported protocol scheme %s", req.URL.Scheme) + } + return &Request{req: req}, nil +} + +// Body returns the original body specified when the Request was created. +func (req *Request) Body() io.ReadSeekCloser { + return req.body +} + +// Raw returns the underlying HTTP request. +func (req *Request) Raw() *http.Request { + return req.req +} + +// Next calls the next policy in the pipeline. +// If there are no more policies, nil and an error are returned. +// This method is intended to be called from pipeline policies. +// To send a request through a pipeline call Pipeline.Do(). +func (req *Request) Next() (*http.Response, error) { + if len(req.policies) == 0 { + return nil, errors.New("no more policies") + } + nextPolicy := req.policies[0] + nextReq := *req + nextReq.policies = nextReq.policies[1:] + return nextPolicy.Do(&nextReq) +} + +// SetOperationValue adds/changes a mutable key/value associated with a single operation. +func (req *Request) SetOperationValue(value interface{}) { + if req.values == nil { + req.values = opValues{} + } + req.values.set(value) +} + +// OperationValue looks for a value set by SetOperationValue(). +func (req *Request) OperationValue(value interface{}) bool { + if req.values == nil { + return false + } + return req.values.get(value) +} + +// SetBody sets the specified ReadSeekCloser as the HTTP request body. +func (req *Request) SetBody(body io.ReadSeekCloser, contentType string) error { + // Set the body and content length. + size, err := body.Seek(0, io.SeekEnd) // Seek to the end to get the stream's size + if err != nil { + return err + } + if size == 0 { + body.Close() + return nil + } + _, err = body.Seek(0, io.SeekStart) + if err != nil { + return err + } + req.Raw().GetBody = func() (io.ReadCloser, error) { + _, err := body.Seek(0, io.SeekStart) // Seek back to the beginning of the stream + return body, err + } + // keep a copy of the original body. this is to handle cases + // where req.Body is replaced, e.g. httputil.DumpRequest and friends. + req.body = body + req.req.Body = body + req.req.ContentLength = size + req.req.Header.Set(shared.HeaderContentType, contentType) + req.req.Header.Set(shared.HeaderContentLength, strconv.FormatInt(size, 10)) + return nil +} + +// SkipBodyDownload will disable automatic downloading of the response body. +func (req *Request) SkipBodyDownload() { + req.SetOperationValue(shared.BodyDownloadPolicyOpValues{Skip: true}) +} + +// RewindBody seeks the request's Body stream back to the beginning so it can be resent when retrying an operation. +func (req *Request) RewindBody() error { + if req.body != nil { + // Reset the stream back to the beginning and restore the body + _, err := req.body.Seek(0, io.SeekStart) + req.req.Body = req.body + return err + } + return nil +} + +// Close closes the request body. +func (req *Request) Close() error { + if req.body == nil { + return nil + } + return req.body.Close() +} + +// Clone returns a deep copy of the request with its context changed to ctx. +func (req *Request) Clone(ctx context.Context) *Request { + r2 := Request{} + r2 = *req + r2.req = req.req.Clone(ctx) + return &r2 +} diff --git a/sdk/azcore/internal/pipeline/request_test.go b/sdk/azcore/internal/pipeline/request_test.go new file mode 100644 index 000000000000..4677417861cc --- /dev/null +++ b/sdk/azcore/internal/pipeline/request_test.go @@ -0,0 +1,139 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "context" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const testURL = "http://test.contoso.com/" + +func TestNewRequest(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPost, testURL) + if err != nil { + t.Fatal(err) + } + if m := req.Raw().Method; m != http.MethodPost { + t.Fatalf("unexpected method %s", m) + } + type myValue struct{} + var mv myValue + if req.OperationValue(&mv) { + t.Fatal("expected missing custom operation value") + } + req.SetOperationValue(myValue{}) + if !req.OperationValue(&mv) { + t.Fatal("missing custom operation value") + } +} + +func TestRequestPolicies(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPost, testURL) + if err != nil { + t.Fatal(err) + } + resp, err := req.Next() + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + req.policies = []Policy{} + resp, err = req.Next() + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + testPolicy := func(*Request) (*http.Response, error) { + return &http.Response{}, nil + } + req.policies = []Policy{PolicyFunc(testPolicy)} + resp, err = req.Next() + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("unexpected nil response") + } +} + +func TestRequestBody(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPost, testURL) + if err != nil { + t.Fatal(err) + } + req.SkipBodyDownload() + if err := req.RewindBody(); err != nil { + t.Fatal(err) + } + if err := req.Close(); err != nil { + t.Fatal(err) + } + if err := req.SetBody(shared.NopCloser(strings.NewReader("test")), "application/text"); err != nil { + t.Fatal(err) + } + if err := req.RewindBody(); err != nil { + t.Fatal(err) + } + if err := req.Close(); err != nil { + t.Fatal(err) + } +} + +func TestRequestClone(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPost, testURL) + if err != nil { + t.Fatal(err) + } + req.SkipBodyDownload() + if err := req.SetBody(shared.NopCloser(strings.NewReader("test")), "application/text"); err != nil { + t.Fatal(err) + } + clone := req.Clone(context.Background()) + var skip shared.BodyDownloadPolicyOpValues + if !clone.OperationValue(&skip) { + t.Fatal("missing operation value") + } + if !skip.Skip { + t.Fatal("wrong operation value") + } + if clone.body == nil { + t.Fatal("missing body") + } +} + +func TestNewRequestFail(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodOptions, "://test.contoso.com/") + if err == nil { + t.Fatal("unexpected nil error") + } + if req != nil { + t.Fatal("unexpected request") + } + req, err = NewRequest(context.Background(), http.MethodPatch, "/missing/the/host") + if err == nil { + t.Fatal("unexpected nil error") + } + if req != nil { + t.Fatal("unexpected request") + } + req, err = NewRequest(context.Background(), http.MethodPatch, "mailto://nobody.contoso.com") + if err == nil { + t.Fatal("unexpected nil error") + } + if req != nil { + t.Fatal("unexpected request") + } +} diff --git a/sdk/azcore/internal/pollers/loc/loc.go b/sdk/azcore/internal/pollers/loc/loc.go new file mode 100644 index 000000000000..357a3d7d5966 --- /dev/null +++ b/sdk/azcore/internal/pollers/loc/loc.go @@ -0,0 +1,80 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "Location" + +// Applicable returns true if the LRO is using Location. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderLocation) != "" +} + +// Poller is an LRO poller that uses the Location pattern. +type Poller struct { + Type string `json:"type"` + PollURL string `json:"pollURL"` + CurState int `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *http.Response, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Location poller.") + locURL := resp.Header.Get(shared.HeaderLocation) + if locURL == "" { + return nil, errors.New("response is missing Location header") + } + if !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid polling URL %s", locURL) + } + return &Poller{ + Type: pollers.MakeID(pollerID, Kind), + PollURL: locURL, + CurState: resp.StatusCode, + }, nil +} + +func (p *Poller) URL() string { + return p.PollURL +} + +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +func (p *Poller) Update(resp *http.Response) error { + // if the endpoint returned a location header, update cached value + if loc := resp.Header.Get(shared.HeaderLocation); loc != "" { + p.PollURL = loc + } + p.CurState = resp.StatusCode + return nil +} + +func (*Poller) FinalGetURL() string { + return "" +} + +func (p *Poller) Status() string { + if p.CurState == http.StatusAccepted { + return pollers.StatusInProgress + } else if p.CurState > 199 && p.CurState < 300 { + // any 2xx other than a 202 indicates success + return pollers.StatusSucceeded + } + return pollers.StatusFailed +} diff --git a/sdk/azcore/internal/pollers/loc/loc_test.go b/sdk/azcore/internal/pollers/loc/loc_test.go new file mode 100644 index 000000000000..6fa70aef14d2 --- /dev/null +++ b/sdk/azcore/internal/pollers/loc/loc_test.go @@ -0,0 +1,136 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakeLocationURL = "https://foo.bar.baz/status" + fakeLocationURL2 = "https://foo.bar.baz/status/other" +) + +func initialResponse() *http.Response { + return &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + } + if Applicable(resp) { + t.Fatal("missing Location should not be applicable") + } + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + if !Applicable(resp) { + t.Fatal("having Location should be applicable") + } +} + +func TestNew(t *testing.T) { + resp := initialResponse() + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatalf("unexpected final get URL %s", u) + } + if s := poller.Status(); s != pollers.StatusInProgress { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeLocationURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewFail(t *testing.T) { + resp := initialResponse() + poller, err := New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } + resp.Header.Set(shared.HeaderLocation, "/must/be/absolute") + poller, err = New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} + +func TestUpdateSucceeded(t *testing.T) { + resp := initialResponse() + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp.Header.Set(shared.HeaderLocation, fakeLocationURL2) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.URL(); u != fakeLocationURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(&http.Response{StatusCode: http.StatusOK}); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusSucceeded { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateFailed(t *testing.T) { + resp := initialResponse() + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp.Header.Set(shared.HeaderLocation, fakeLocationURL2) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.URL(); u != fakeLocationURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(&http.Response{StatusCode: http.StatusConflict}); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusFailed { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/azcore/internal/pollers/op/op.go b/sdk/azcore/internal/pollers/op/op.go new file mode 100644 index 000000000000..730a85dfa795 --- /dev/null +++ b/sdk/azcore/internal/pollers/op/op.go @@ -0,0 +1,132 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package op + +import ( + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "Operation-Location" + +// Applicable returns true if the LRO is using Operation-Location. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderOperationLocation) != "" +} + +// Poller is an LRO poller that uses the Operation-Location pattern. +type Poller struct { + Type string `json:"type"` + PollURL string `json:"pollURL"` + LocURL string `json:"locURL"` + FinalGET string `json:"finalGET"` + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *http.Response, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Operation-Location poller.") + opURL := resp.Header.Get(shared.HeaderOperationLocation) + if opURL == "" { + return nil, errors.New("response is missing Operation-Location header") + } + if !pollers.IsValidURL(opURL) { + return nil, fmt.Errorf("invalid Operation-Location URL %s", opURL) + } + locURL := resp.Header.Get(shared.HeaderLocation) + // Location header is optional + if locURL != "" && !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid Location URL %s", locURL) + } + // default initial state to InProgress. if the + // service sent us a status then use that instead. + curState := pollers.StatusInProgress + status, err := getValue(resp, "status") + if err != nil && !errors.Is(err, shared.ErrNoBody) { + return nil, err + } + if status != "" { + curState = status + } + // calculate the tentative final GET URL. + // can change if we receive a resourceLocation. + // it's ok for it to be empty in some cases. + finalGET := "" + if resp.Request.Method == http.MethodPatch || resp.Request.Method == http.MethodPut { + finalGET = resp.Request.URL.String() + } else if resp.Request.Method == http.MethodPost && locURL != "" { + finalGET = locURL + } + return &Poller{ + Type: pollers.MakeID(pollerID, Kind), + PollURL: opURL, + LocURL: locURL, + FinalGET: finalGET, + CurState: curState, + }, nil +} + +func (p *Poller) URL() string { + return p.PollURL +} + +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +func (p *Poller) Update(resp *http.Response) error { + status, err := getValue(resp, "status") + if err != nil { + return err + } else if status == "" { + return errors.New("the response did not contain a status") + } + p.CurState = status + // if the endpoint returned an operation-location header, update cached value + if opLoc := resp.Header.Get(shared.HeaderOperationLocation); opLoc != "" { + p.PollURL = opLoc + } + // check for resourceLocation + resLoc, err := getValue(resp, "resourceLocation") + if err != nil && !errors.Is(err, shared.ErrNoBody) { + return err + } else if resLoc != "" { + p.FinalGET = resLoc + } + return nil +} + +func (p *Poller) FinalGetURL() string { + return p.FinalGET +} + +func (p *Poller) Status() string { + return p.CurState +} + +func getValue(resp *http.Response, val string) (string, error) { + jsonBody, err := shared.GetJSON(resp) + if err != nil { + return "", err + } + v, ok := jsonBody[val] + if !ok { + // it might be ok if the field doesn't exist, the caller must make that determination + return "", nil + } + vv, ok := v.(string) + if !ok { + return "", fmt.Errorf("the %s value %v was not in string format", val, v) + } + return vv, nil +} diff --git a/sdk/azcore/internal/pollers/op/op_test.go b/sdk/azcore/internal/pollers/op/op_test.go new file mode 100644 index 000000000000..55c3f2253ea4 --- /dev/null +++ b/sdk/azcore/internal/pollers/op/op_test.go @@ -0,0 +1,249 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package op + +import ( + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakePollingURL = "https://foo.bar.baz/status" + fakePollingURL2 = "https://foo.bar.baz/status/updated" + fakeLocationURL = "https://foo.bar.baz/location" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, body io.Reader) *http.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: ioutil.NopCloser(body), + Header: http.Header{}, + Request: req, + } +} + +func createResponse(body io.Reader) *http.Response { + return &http.Response{ + Body: ioutil.NopCloser(body), + Header: http.Header{}, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + } + if Applicable(resp) { + t.Fatal("missing Operation-Location should not be applicable") + } + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + if !Applicable(resp) { + t.Fatal("having Operation-Location should be applicable") + } +} + +func TestNew(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if s := poller.Status(); s != pollers.StatusInProgress { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected URL %s", u) + } +} + +func TestNewWithInitialStatus(t *testing.T) { + resp := initialResponse(http.MethodPut, strings.NewReader(`{ "status": "Updating" }`)) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if s := poller.Status(); s != "Updating" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewWithPost(t *testing.T) { + resp := initialResponse(http.MethodPost, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.FinalGetURL(); u != fakeLocationURL { + t.Fatalf("unexpected final get URL %s", u) + } +} + +func TestNewWithDelete(t *testing.T) { + resp := initialResponse(http.MethodDelete, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatalf("unexpected final get URL %s", u) + } +} + +func TestNewFail(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + poller, err := New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, "/must/be/absolute") + poller, err = New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } + resp.Header.Set(shared.HeaderOperationLocation, "/must/be/absolute") + poller, err = New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} + +func TestUpdateSucceeded(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp = createResponse(strings.NewReader(`{ "status": "Running" }`)) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL2) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if s := poller.Status(); s != "Running" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected URL %s", u) + } + resp = createResponse(strings.NewReader(`{ "status": "Succeeded" }`)) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusSucceeded { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateResourceLocation(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp = createResponse(strings.NewReader(`{ "status": "Succeeded", "resourceLocation": "https://foo.bar.baz/resource2" }`)) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusSucceeded { + t.Fatalf("unexpected status %s", s) + } + if u := poller.FinalGetURL(); u != "https://foo.bar.baz/resource2" { + t.Fatalf("unexpected final get url %s", u) + } +} + +func TestUpdateFailed(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp = createResponse(strings.NewReader(`{ "status": "Failed" }`)) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusFailed { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateMissingStatus(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp = createResponse(http.NoBody) + if err := poller.Update(resp); err == nil { + t.Fatal("unexpected nil error") + } + if poller.Done() { + t.Fatal("unexpected done") + } +} diff --git a/sdk/azcore/internal/pollers/poller.go b/sdk/azcore/internal/pollers/poller.go new file mode 100644 index 000000000000..aca2f3197b72 --- /dev/null +++ b/sdk/azcore/internal/pollers/poller.go @@ -0,0 +1,213 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "reflect" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// KindFromToken extracts the poller kind from the provided token. +// If the pollerID doesn't match what's in the token an error is returned. +func KindFromToken(pollerID, token string) (string, error) { + // unmarshal into JSON object to determine the poller type + obj := map[string]interface{}{} + err := json.Unmarshal([]byte(token), &obj) + if err != nil { + return "", err + } + t, ok := obj["type"] + if !ok { + return "", errors.New("missing type field") + } + tt, ok := t.(string) + if !ok { + return "", fmt.Errorf("invalid type format %T", t) + } + ttID, ttKind, err := DecodeID(tt) + if err != nil { + return "", err + } + // ensure poller types match + if ttID != pollerID { + return "", fmt.Errorf("cannot resume from this poller token. expected %s, received %s", pollerID, ttID) + } + return ttKind, nil +} + +// PollerType returns the concrete type of the poller (FOR TESTING PURPOSES). +func PollerType(p *Poller) reflect.Type { + return reflect.TypeOf(p.lro) +} + +// NewPoller creates a Poller from the specified input. +func NewPoller(lro Operation, resp *http.Response, pl pipeline.Pipeline, eu func(*http.Response) error) *Poller { + return &Poller{lro: lro, pl: pl, eu: eu, resp: resp} +} + +// Poller encapsulates state and logic for polling on long-running operations. +type Poller struct { + lro Operation + pl pipeline.Pipeline + eu func(*http.Response) error + resp *http.Response + err error +} + +// Done returns true if the LRO has reached a terminal state. +func (l *Poller) Done() bool { + if l.err != nil { + return true + } + return l.lro.Done() +} + +// Poll sends a polling request to the polling endpoint and returns the response or error. +func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { + if l.Done() { + // the LRO has reached a terminal state, don't poll again + if l.resp != nil { + return l.resp, nil + } + return nil, l.err + } + req, err := pipeline.NewRequest(ctx, http.MethodGet, l.lro.URL()) + if err != nil { + return nil, err + } + resp, err := l.pl.Do(req) + if err != nil { + // don't update the poller for failed requests + return nil, err + } + defer resp.Body.Close() + if !StatusCodeValid(resp) { + // the LRO failed. unmarshall the error and update state + l.err = l.eu(resp) + l.resp = nil + return nil, l.err + } + if err = l.lro.Update(resp); err != nil { + return nil, err + } + l.resp = resp + log.Writef(log.LongRunningOperation, "Status %s", l.lro.Status()) + if Failed(l.lro.Status()) { + l.err = l.eu(resp) + l.resp = nil + return nil, l.err + } + return l.resp, nil +} + +// ResumeToken returns a token string that can be used to resume a poller that has not yet reached a terminal state. +func (l *Poller) ResumeToken() (string, error) { + if l.Done() { + return "", errors.New("cannot create a ResumeToken from a poller in a terminal state") + } + b, err := json.Marshal(l.lro) + if err != nil { + return "", err + } + return string(b), nil +} + +// FinalResponse will perform a final GET request and return the final HTTP response for the polling +// operation and unmarshall the content of the payload into the respType interface that is provided. +func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http.Response, error) { + if !l.Done() { + return nil, errors.New("cannot return a final response from a poller in a non-terminal state") + } + // update l.resp with the content from final GET if applicable + if u := l.lro.FinalGetURL(); u != "" { + log.Write(log.LongRunningOperation, "Performing final GET.") + req, err := pipeline.NewRequest(ctx, http.MethodGet, u) + if err != nil { + return nil, err + } + resp, err := l.pl.Do(req) + if err != nil { + return nil, err + } + if !StatusCodeValid(resp) { + return nil, l.eu(resp) + } + l.resp = resp + } + // if there's nothing to unmarshall into or no response body just return the final response + if respType == nil { + return l.resp, nil + } else if l.resp.StatusCode == http.StatusNoContent || l.resp.ContentLength == 0 { + log.Write(log.LongRunningOperation, "final response specifies a response type but no payload was received") + return l.resp, nil + } + body, err := ioutil.ReadAll(l.resp.Body) + l.resp.Body.Close() + if err != nil { + return nil, err + } + if err = json.Unmarshal(body, respType); err != nil { + return nil, err + } + return l.resp, nil +} + +// PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, +// then return the final HTTP response for the polling operation and unmarshal the content of the payload +// into the respType interface that is provided. +// freq - the time to wait between polling intervals if the endpoint doesn't send a Retry-After header. +// A good starting value is 30 seconds. Note that some resources might benefit from a different value. +func (l *Poller) PollUntilDone(ctx context.Context, freq time.Duration, respType interface{}) (*http.Response, error) { + start := time.Now() + logPollUntilDoneExit := func(v interface{}) { + log.Writef(log.LongRunningOperation, "END PollUntilDone() for %T: %v, total time: %s", l.lro, v, time.Since(start)) + } + log.Writef(log.LongRunningOperation, "BEGIN PollUntilDone() for %T", l.lro) + if l.resp != nil { + // initial check for a retry-after header existing on the initial response + if retryAfter := shared.RetryAfter(l.resp); retryAfter > 0 { + log.Writef(log.LongRunningOperation, "initial Retry-After delay for %s", retryAfter.String()) + if err := shared.Delay(ctx, retryAfter); err != nil { + logPollUntilDoneExit(err) + return nil, err + } + } + } + // begin polling the endpoint until a terminal state is reached + for { + resp, err := l.Poll(ctx) + if err != nil { + logPollUntilDoneExit(err) + return nil, err + } + if l.Done() { + logPollUntilDoneExit(l.lro.Status()) + return l.FinalResponse(ctx, respType) + } + d := freq + if retryAfter := shared.RetryAfter(resp); retryAfter > 0 { + log.Writef(log.LongRunningOperation, "Retry-After delay for %s", retryAfter.String()) + d = retryAfter + } else { + log.Writef(log.LongRunningOperation, "delay for %s", d.String()) + } + if err = shared.Delay(ctx, d); err != nil { + logPollUntilDoneExit(err) + return nil, err + } + } +} diff --git a/sdk/azcore/internal/pollers/poller_test.go b/sdk/azcore/internal/pollers/poller_test.go new file mode 100644 index 000000000000..01dc426cab51 --- /dev/null +++ b/sdk/azcore/internal/pollers/poller_test.go @@ -0,0 +1,256 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +func TestKindFromToken(t *testing.T) { + const tk = `{ "type": "pollerID;kind" }` + k, err := KindFromToken("pollerID", tk) + if err != nil { + t.Fatal(err) + } + if k != "kind" { + t.Fatalf("unexpected kind %s", k) + } + k, err = KindFromToken("mismatched", tk) + if err == nil { + t.Fatal("unexpected nil error") + } + if k != "" { + t.Fatal("expected empty kind") + } +} + +func TestKindFromTokenInvalid(t *testing.T) { + const tk1 = `{ "missing": "type" }` + k, err := KindFromToken("mismatched", tk1) + if err == nil { + t.Fatal("unexpected nil error") + } + if k != "" { + t.Fatal("expected empty kind") + } + const tk2 = `{ "type": false }` + k, err = KindFromToken("mismatched", tk2) + if err == nil { + t.Fatal("unexpected nil error") + } + if k != "" { + t.Fatal("expected empty kind") + } + const tk3 = `{ "type": "pollerID;kind;extra" }` + k, err = KindFromToken("mismatched", tk3) + if err == nil { + t.Fatal("unexpected nil error") + } + if k != "" { + t.Fatal("expected empty kind") + } +} + +// simple status code-based poller +type fakePoller struct { + Ep string + Fg string + Code int +} + +func (f *fakePoller) Done() bool { + return f.Code == http.StatusOK || f.Code == http.StatusNoContent +} + +func (f *fakePoller) Update(resp *http.Response) error { + f.Code = resp.StatusCode + return nil +} + +func (f *fakePoller) FinalGetURL() string { + return f.Fg +} + +func (f *fakePoller) URL() string { + return f.Ep +} + +func (f *fakePoller) Status() string { + switch f.Code { + case http.StatusAccepted: + return StatusInProgress + case http.StatusOK, http.StatusNoContent: + return StatusSucceeded + case http.StatusCreated: + return StatusCanceled + default: + return StatusFailed + } +} + +func TestNewPoller(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusNoContent)) // terminal + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{}, + } + firstResp.Header.Set(shared.HeaderRetryAfter, "1") + p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + if p.Done() { + t.Fatal("unexpected done") + } + resp, err := p.FinalResponse(context.Background(), nil) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + tk, err := p.ResumeToken() + if err != nil { + t.Fatal(err) + } + if tk == "" { + t.Fatal("unexpected empty resume token") + } + resp, err = p.PollUntilDone(context.Background(), 1*time.Millisecond, nil) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } + tk, err = p.ResumeToken() + if err == nil { + t.Fatal("unexpected nil error") + } + if tk != "" { + t.Fatal("expected empty resume token") + } +} + +func TestNewPollerWithFinalGET(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithHeader(shared.HeaderRetryAfter, "1")) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) // terminal + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{ "shape": "round" }`))) // final GET + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + } + p := NewPoller(&fakePoller{Ep: srv.URL(), Fg: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + if p.Done() { + t.Fatal("unexpected done") + } + type widget struct { + Shape string `json:"shape"` + } + var w widget + resp, err := p.PollUntilDone(context.Background(), 1*time.Millisecond, &w) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } + if w.Shape != "round" { + t.Fatalf("unexpected result %s", w.Shape) + } + resp, err = p.Poll(context.Background()) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestNewPollerFail1(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict)) // terminal + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + } + p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + resp, err := p.PollUntilDone(context.Background(), 1*time.Millisecond, nil) + if err == nil { + t.Fatal("unexpected nil error") + } else if s := err.Error(); s != "failed" { + t.Fatalf("unexpected error %s", s) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestNewPollerFail2(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusCreated)) // terminal + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + } + p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + resp, err := p.PollUntilDone(context.Background(), 1*time.Millisecond, nil) + if err == nil { + t.Fatal("unexpected nil error") + } else if s := err.Error(); s != "failed" { + t.Fatalf("unexpected error %s", s) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestNewPollerError(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendError(errors.New("fatal")) + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + } + p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + resp, err := p.PollUntilDone(context.Background(), 1*time.Millisecond, nil) + if err == nil { + t.Fatal("unexpected nil error") + } else if s := err.Error(); s != "fatal" { + t.Fatalf("unexpected error %s", s) + } + if resp != nil { + t.Fatal("expected nil response") + } +} diff --git a/sdk/azcore/internal/pollers/util.go b/sdk/azcore/internal/pollers/util.go new file mode 100644 index 000000000000..dca70b5a596b --- /dev/null +++ b/sdk/azcore/internal/pollers/util.go @@ -0,0 +1,99 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + StatusSucceeded = "Succeeded" + StatusCanceled = "Canceled" + StatusFailed = "Failed" + StatusInProgress = "InProgress" +) + +// Operation abstracts the differences between concrete poller types. +type Operation interface { + Done() bool + Update(resp *http.Response) error + FinalGetURL() string + URL() string + Status() string +} + +// IsTerminalState returns true if the LRO's state is terminal. +func IsTerminalState(s string) bool { + return strings.EqualFold(s, StatusSucceeded) || strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) +} + +// Failed returns true if the LRO's state is terminal failure. +func Failed(s string) bool { + return strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) +} + +// returns true if the LRO response contains a valid HTTP status code +func StatusCodeValid(resp *http.Response) bool { + return shared.HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) +} + +// IsValidURL verifies that the URL is valid and absolute. +func IsValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() +} + +const idSeparator = ";" + +// MakeID returns the poller ID from the provided values. +func MakeID(pollerID string, kind string) string { + return fmt.Sprintf("%s%s%s", pollerID, idSeparator, kind) +} + +// DecodeID decodes the poller ID, returning [pollerID, kind] or an error. +func DecodeID(tk string) (string, string, error) { + raw := strings.Split(tk, idSeparator) + // strings.Split will include any/all whitespace strings, we want to omit those + parts := []string{} + for _, r := range raw { + if s := strings.TrimSpace(r); s != "" { + parts = append(parts, s) + } + } + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid token %s", tk) + } + return parts[0], parts[1], nil +} + +// used if the operation synchronously completed +type NopPoller struct{} + +func (*NopPoller) URL() string { + return "" +} + +func (*NopPoller) Done() bool { + return true +} + +func (*NopPoller) Update(*http.Response) error { + return nil +} + +func (*NopPoller) FinalGetURL() string { + return "" +} + +func (*NopPoller) Status() string { + return StatusSucceeded +} diff --git a/sdk/azcore/internal/pollers/util_test.go b/sdk/azcore/internal/pollers/util_test.go new file mode 100644 index 000000000000..04932bfcdf73 --- /dev/null +++ b/sdk/azcore/internal/pollers/util_test.go @@ -0,0 +1,169 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "net/http" + "strings" + "testing" +) + +func TestIsTerminalState(t *testing.T) { + if IsTerminalState("Updating") { + t.Fatal("Updating is not a terminal state") + } + if !IsTerminalState("Succeeded") { + t.Fatal("Succeeded is a terminal state") + } + if !IsTerminalState("failed") { + t.Fatal("failed is a terminal state") + } + if !IsTerminalState("canceled") { + t.Fatal("canceled is a terminal state") + } +} + +func TestStatusCodeValid(t *testing.T) { + if !StatusCodeValid(&http.Response{StatusCode: http.StatusOK}) { + t.Fatal("unexpected valid code") + } + if !StatusCodeValid(&http.Response{StatusCode: http.StatusAccepted}) { + t.Fatal("unexpected valid code") + } + if !StatusCodeValid(&http.Response{StatusCode: http.StatusCreated}) { + t.Fatal("unexpected valid code") + } + if !StatusCodeValid(&http.Response{StatusCode: http.StatusNoContent}) { + t.Fatal("unexpected valid code") + } + if StatusCodeValid(&http.Response{StatusCode: http.StatusPartialContent}) { + t.Fatal("unexpected valid code") + } + if StatusCodeValid(&http.Response{StatusCode: http.StatusBadRequest}) { + t.Fatal("unexpected valid code") + } + if StatusCodeValid(&http.Response{StatusCode: http.StatusInternalServerError}) { + t.Fatal("unexpected valid code") + } +} + +func TestMakeID(t *testing.T) { + const ( + pollerID = "pollerID" + kind = "kind" + ) + id := MakeID(pollerID, kind) + parts := strings.Split(id, idSeparator) + if l := len(parts); l != 2 { + t.Fatalf("unexpected length %d", l) + } + if p := parts[0]; p != pollerID { + t.Fatalf("unexpected poller ID %s", p) + } + if p := parts[1]; p != kind { + t.Fatalf("unexpected poller kind %s", p) + } +} + +func TestDecodeID(t *testing.T) { + _, _, err := DecodeID("") + if err == nil { + t.Fatal("unexpected nil error") + } + _, _, err = DecodeID("invalid_token") + if err == nil { + t.Fatal("unexpected nil error") + } + _, _, err = DecodeID("invalid_token;") + if err == nil { + t.Fatal("unexpected nil error") + } + _, _, err = DecodeID(" ;invalid_token") + if err == nil { + t.Fatal("unexpected nil error") + } + _, _, err = DecodeID("invalid;token;too") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err := DecodeID("pollerID;kind") + if err != nil { + t.Fatal(err) + } + if id != "pollerID" { + t.Fatalf("unexpected ID %s", id) + } + if kind != "kind" { + t.Fatalf("unexpected kin %s", kind) + } +} + +func TestIsValidURL(t *testing.T) { + if IsValidURL("/foo") { + t.Fatal("unexpected valid URL") + } + if !IsValidURL("https://foo.bar/baz") { + t.Fatal("expected valid URL") + } +} + +func TestFailed(t *testing.T) { + if Failed("Succeeded") || Failed("Updating") { + t.Fatal("unexpected failure") + } + if !Failed("failed") { + t.Fatal("expected failure") + } +} + +func TestNopPoller(t *testing.T) { + np := NopPoller{} + if !np.Done() { + t.Fatal("expected done") + } + if np.FinalGetURL() != "" { + t.Fatal("expected empty final get URL") + } + if np.Status() != StatusSucceeded { + t.Fatal("expected Succeeded") + } + if np.URL() != "" { + t.Fatal("expected empty URL") + } + if err := np.Update(nil); err != nil { + t.Fatal(err) + } +} + +/*func TestNewPollerNop(t *testing.T) { + srv, close := mock.NewServer() + defer close() + resp := initialResponse(http.MethodPost, srv.URL(), strings.NewReader(successResp)) + resp.StatusCode = http.StatusOK + poller, err := NewPoller("pollerID", "", resp, getPipeline(srv), handleError) + if err != nil { + t.Fatal(err) + } + if _, ok := poller.lro.(*nopPoller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) + } + tk, err := poller.ResumeToken() + if err == nil { + t.Fatal("unexpected nil error") + } + if tk != "" { + t.Fatal("expected empty token") + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +}*/ diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go new file mode 100644 index 000000000000..d2862d8efa3f --- /dev/null +++ b/sdk/azcore/internal/shared/constants.go @@ -0,0 +1,34 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +const ( + ContentTypeAppJSON = "application/json" + ContentTypeAppXML = "application/xml" +) + +const ( + HeaderAzureAsync = "Azure-AsyncOperation" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderLocation = "Location" + HeaderOperationLocation = "Operation-Location" + HeaderRetryAfter = "Retry-After" + HeaderUserAgent = "User-Agent" +) + +const ( + DefaultMaxRetries = 3 +) + +const ( + // Module is the name of the calling module used in telemetry data. + Module = "azcore" + + // Version is the semantic version (see http://semver.org) of this module. + Version = "v0.19.0" +) diff --git a/sdk/azcore/internal/shared/shared.go b/sdk/azcore/internal/shared/shared.go new file mode 100644 index 000000000000..c275221f8886 --- /dev/null +++ b/sdk/azcore/internal/shared/shared.go @@ -0,0 +1,148 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "io/ioutil" + "net/http" + "strconv" + "time" +) + +// CtxWithHTTPHeaderKey is used as a context key for adding/retrieving http.Header. +type CtxWithHTTPHeaderKey struct{} + +// CtxWithRetryOptionsKey is used as a context key for adding/retrieving RetryOptions. +type CtxWithRetryOptionsKey struct{} + +type nopCloser struct { + io.ReadSeeker +} + +func (n nopCloser) Close() error { + return nil +} + +// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. +func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { + return nopCloser{rs} +} + +// BodyDownloadPolicyOpValues is the struct containing the per-operation values +type BodyDownloadPolicyOpValues struct { + Skip bool +} + +func NewResponseError(inner error, resp *http.Response) error { + return &ResponseError{inner: inner, resp: resp} +} + +type ResponseError struct { + inner error + resp *http.Response +} + +// Error implements the error interface for type ResponseError. +func (e *ResponseError) Error() string { + return e.inner.Error() +} + +// Unwrap returns the inner error. +func (e *ResponseError) Unwrap() error { + return e.inner +} + +// RawResponse returns the HTTP response associated with this error. +func (e *ResponseError) RawResponse() *http.Response { + return e.resp +} + +// NonRetriable indicates this error is non-transient. +func (e *ResponseError) NonRetriable() { + // marker method +} + +// Delay waits for the duration to elapse or the context to be cancelled. +func Delay(ctx context.Context, delay time.Duration) error { + select { + case <-time.After(delay): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// ErrNoBody is returned if the response didn't contain a body. +var ErrNoBody = errors.New("the response did not contain a body") + +// GetJSON reads the response body into a raw JSON object. +// It returns ErrNoBody if there was no content. +func GetJSON(resp *http.Response) (map[string]interface{}, error) { + body, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + return nil, err + } + if len(body) == 0 { + return nil, ErrNoBody + } + // put the body back so it's available to others + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + // unmarshall the body to get the value + var jsonBody map[string]interface{} + if err = json.Unmarshal(body, &jsonBody); err != nil { + return nil, err + } + return jsonBody, nil +} + +// RetryAfter returns non-zero if the response contains a Retry-After header value. +func RetryAfter(resp *http.Response) time.Duration { + if resp == nil { + return 0 + } + ra := resp.Header.Get(HeaderRetryAfter) + if ra == "" { + return 0 + } + // retry-after values are expressed in either number of + // seconds or an HTTP-date indicating when to try again + if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 { + return time.Duration(retryAfter) * time.Second + } else if t, err := time.Parse(time.RFC1123, ra); err == nil { + return time.Until(t) + } + return 0 +} + +// HasStatusCode returns true if the Response's status code is one of the specified values. +func HasStatusCode(resp *http.Response, statusCodes ...int) bool { + if resp == nil { + return false + } + for _, sc := range statusCodes { + if resp.StatusCode == sc { + return true + } + } + return false +} + +const defaultScope = "/.default" + +// EndpointToScope converts the provided URL endpoint to its default scope. +func EndpointToScope(endpoint string) string { + if endpoint[len(endpoint)-1] != '/' { + endpoint += "/" + } + return endpoint + defaultScope +} diff --git a/sdk/azcore/internal/shared/shared_test.go b/sdk/azcore/internal/shared/shared_test.go new file mode 100644 index 000000000000..225b89cc6bc4 --- /dev/null +++ b/sdk/azcore/internal/shared/shared_test.go @@ -0,0 +1,133 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" +) + +func TestNopCloser(t *testing.T) { + nc := NopCloser(strings.NewReader("foo")) + if err := nc.Close(); err != nil { + t.Fatal(err) + } +} + +type testError struct { + m string +} + +func (t testError) Error() string { + return t.m +} + +func TestNewResponseError(t *testing.T) { + err := NewResponseError(testError{m: "crash"}, &http.Response{StatusCode: http.StatusInternalServerError}) + if s := err.Error(); s != "crash" { + t.Fatalf("unexpected error %s", s) + } + re, ok := err.(*ResponseError) + if !ok { + t.Fatalf("unexpected error type %T", err) + } + re.NonRetriable() // nop + if c := re.RawResponse().StatusCode; c != http.StatusInternalServerError { + t.Fatalf("unexpected status code %d", c) + } + var te testError + if !errors.As(err, &te) { + t.Fatal("unwrap failed") + } +} + +func TestDelay(t *testing.T) { + if err := Delay(context.Background(), 5*time.Millisecond); err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := Delay(ctx, 5*time.Minute); err == nil { + t.Fatal("unexpected nil error") + } +} + +func TestGetJSON(t *testing.T) { + j, err := GetJSON(&http.Response{Body: http.NoBody}) + if !errors.Is(err, ErrNoBody) { + t.Fatal(err) + } + if j != nil { + t.Fatal("expected nil json") + } + j, err = GetJSON(&http.Response{Body: ioutil.NopCloser(strings.NewReader(`{ "foo": "bar" }`))}) + if err != nil { + t.Fatal(err) + } + if v := j["foo"]; v != "bar" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestRetryAfter(t *testing.T) { + if RetryAfter(nil) != 0 { + t.Fatal("expected zero duration") + } + resp := &http.Response{ + Header: http.Header{}, + } + if d := RetryAfter(resp); d > 0 { + t.Fatalf("unexpected retry-after value %d", d) + } + resp.Header.Set(HeaderRetryAfter, "300") + d := RetryAfter(resp) + if d <= 0 { + t.Fatal("expected retry-after value from seconds") + } + if d != 300*time.Second { + t.Fatalf("expected 300 seconds, got %d", d/time.Second) + } + atDate := time.Now().Add(600 * time.Second) + resp.Header.Set(HeaderRetryAfter, atDate.Format(time.RFC1123)) + d = RetryAfter(resp) + if d <= 0 { + t.Fatal("expected retry-after value from date") + } + // d will not be exactly 600 seconds but it will be close + if s := d / time.Second; s < 598 || s > 602 { + t.Fatalf("expected ~600 seconds, got %d", s) + } +} + +func TestHasStatusCode(t *testing.T) { + if HasStatusCode(nil, http.StatusAccepted) { + t.Fatal("unexpected success") + } + if HasStatusCode(&http.Response{}) { + t.Fatal("unexpected success") + } + if HasStatusCode(&http.Response{StatusCode: http.StatusBadGateway}, http.StatusBadRequest) { + t.Fatal("unexpected success") + } + if !HasStatusCode(&http.Response{StatusCode: http.StatusOK}, http.StatusAccepted, http.StatusOK, http.StatusNoContent) { + t.Fatal("unexpected failure") + } +} + +func TestEndpointToScope(t *testing.T) { + if s := EndpointToScope("https://management.microsoftazure.de/"); s != "https://management.microsoftazure.de//.default" { + t.Fatalf("unexpected scope %s", s) + } + if s := EndpointToScope("https://management.usgovcloudapi.net"); s != "https://management.usgovcloudapi.net//.default" { + t.Fatalf("unexpected scope %s", s) + } +} diff --git a/sdk/azcore/policy/policy.go b/sdk/azcore/policy/policy.go new file mode 100644 index 000000000000..77722339f44f --- /dev/null +++ b/sdk/azcore/policy/policy.go @@ -0,0 +1,96 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package policy + +import ( + "context" + "net/http" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// Policy represents an extensibility point for the Pipeline that can mutate the specified +// Request and react to the received Response. +type Policy = pipeline.Policy + +// Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. +type Transporter = pipeline.Transporter + +// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. +// Don't use this type directly, use runtime.NewRequest() instead. +type Request = pipeline.Request + +// LogOptions configures the logging policy's behavior. +type LogOptions struct { + // IncludeBody indicates if request and response bodies should be included in logging. + // The default value is false. + // NOTE: enabling this can lead to disclosure of sensitive information, use with care. + IncludeBody bool +} + +// RetryOptions configures the retry policy's behavior. +// Call NewRetryOptions() to create an instance with default values. +type RetryOptions struct { + // MaxRetries specifies the maximum number of attempts a failed operation will be retried + // before producing an error. + // The default value is three. A value less than zero means one try and no retries. + MaxRetries int32 + + // TryTimeout indicates the maximum time allowed for any single try of an HTTP request. + // This is disabled by default. Specify a value greater than zero to enable. + // NOTE: Setting this to a small value might cause premature HTTP request time-outs. + TryTimeout time.Duration + + // RetryDelay specifies the initial amount of delay to use before retrying an operation. + // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay. + // The default value is four seconds. A value less than zero means no delay between retries. + RetryDelay time.Duration + + // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. + // Typically the value is greater than or equal to the value specified in RetryDelay. + // The default Value is 120 seconds. A value less than zero means there is no cap. + MaxRetryDelay time.Duration + + // StatusCodes specifies the HTTP status codes that indicate the operation should be retried. + // The default value is the status codes in StatusCodesForRetry. + // Specifying an empty slice will cause retries to happen only for transport errors. + StatusCodes []int +} + +// TelemetryOptions configures the telemetry policy's behavior. +type TelemetryOptions struct { + // ApplicationID is an application-specific identification string used in telemetry. + // It has a maximum length of 24 characters and must not contain any spaces. + ApplicationID string + + // Disabled will prevent the addition of any telemetry data to the User-Agent. + Disabled bool +} + +// TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. +type TokenRequestOptions struct { + // Scopes contains the list of permission scopes required for the token. + Scopes []string + // TenantID contains the tenant ID to use in a multi-tenant authentication scenario, if TenantID is set + // it will override the tenant ID that was added at credential creation time. + TenantID string +} + +// WithHTTPHeader adds the specified http.Header to the parent context. +// Use this to specify custom HTTP headers at the API-call level. +// Any overlapping headers will have their values replaced with the values specified here. +func WithHTTPHeader(parent context.Context, header http.Header) context.Context { + return context.WithValue(parent, shared.CtxWithHTTPHeaderKey{}, header) +} + +// WithRetryOptions adds the specified RetryOptions to the parent context. +// Use this to specify custom RetryOptions at the API-call level. +func WithRetryOptions(parent context.Context, options RetryOptions) context.Context { + return context.WithValue(parent, shared.CtxWithRetryOptionsKey{}, options) +} diff --git a/sdk/azcore/policy/policy_test.go b/sdk/azcore/policy/policy_test.go new file mode 100644 index 000000000000..65bcf125cbe4 --- /dev/null +++ b/sdk/azcore/policy/policy_test.go @@ -0,0 +1,54 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package policy + +import ( + "context" + "math" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +func TestWithHTTPHeader(t *testing.T) { + const ( + key = "some" + val = "thing" + ) + input := http.Header{} + input.Set(key, val) + ctx := WithHTTPHeader(context.Background(), input) + if ctx == nil { + t.Fatal("nil context") + } + raw := ctx.Value(shared.CtxWithHTTPHeaderKey{}) + header, ok := raw.(http.Header) + if !ok { + t.Fatalf("unexpected type %T", raw) + } + if v := header.Get(key); v != val { + t.Fatalf("unexpected value %s", v) + } +} + +func TestWithRetryOptions(t *testing.T) { + ctx := WithRetryOptions(context.Background(), RetryOptions{ + MaxRetries: math.MaxInt32, + }) + if ctx == nil { + t.Fatal("nil context") + } + raw := ctx.Value(shared.CtxWithRetryOptionsKey{}) + opts, ok := raw.(RetryOptions) + if !ok { + t.Fatalf("unexpected type %T", raw) + } + if opts.MaxRetries != math.MaxInt32 { + t.Fatalf("unexpected value %d", opts.MaxRetries) + } +} diff --git a/sdk/azcore/policy_anonymous_credential.go b/sdk/azcore/policy_anonymous_credential.go index 496b976f561c..45f3e993b118 100644 --- a/sdk/azcore/policy_anonymous_credential.go +++ b/sdk/azcore/policy_anonymous_credential.go @@ -6,13 +6,19 @@ package azcore -import "net/http" +import ( + "net/http" -func anonCredAuthPolicyFunc(AuthenticationOptions) Policy { - return policyFunc(anonCredPolicyFunc) + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) + +func anonCredAuthPolicyFunc(runtime.AuthenticationOptions) policy.Policy { + return pipeline.PolicyFunc(anonCredPolicyFunc) } -func anonCredPolicyFunc(req *Request) (*http.Response, error) { +func anonCredPolicyFunc(req *policy.Request) (*http.Response, error) { return req.Next() } diff --git a/sdk/azcore/policy_anonymous_credential_test.go b/sdk/azcore/policy_anonymous_credential_test.go index e625470dbcf7..3f6d88a743d5 100644 --- a/sdk/azcore/policy_anonymous_credential_test.go +++ b/sdk/azcore/policy_anonymous_credential_test.go @@ -12,6 +12,7 @@ import ( "reflect" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -19,8 +20,8 @@ func TestAnonymousCredential(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithStatusCode(http.StatusOK)) - pl := NewPipeline(srv, NewAnonymousCredential().NewAuthenticationPolicy(AuthenticationOptions{})) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + pl := runtime.NewPipeline(srv, NewAnonymousCredential().NewAuthenticationPolicy(runtime.AuthenticationOptions{})) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -28,7 +29,7 @@ func TestAnonymousCredential(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !reflect.DeepEqual(req.Header, resp.Request.Header) { + if !reflect.DeepEqual(req.Raw().Header, resp.Request.Header) { t.Fatal("unexpected modification to request headers") } } diff --git a/sdk/azcore/policy_http_header.go b/sdk/azcore/policy_http_header.go deleted file mode 100644 index ba7650280d9f..000000000000 --- a/sdk/azcore/policy_http_header.go +++ /dev/null @@ -1,39 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "context" - "net/http" -) - -// used as a context key for adding/retrieving http.Header -type ctxWithHTTPHeader struct{} - -// newHTTPHeaderPolicy creates a policy object that adds custom HTTP headers to a request -func httpHeaderPolicy(req *Request) (*http.Response, error) { - // check if any custom HTTP headers have been specified - if header := req.Context().Value(ctxWithHTTPHeader{}); header != nil { - for k, v := range header.(http.Header) { - // use Set to replace any existing value - // it also canonicalizes the header key - req.Header.Set(k, v[0]) - // add any remaining values - for i := 1; i < len(v); i++ { - req.Header.Add(k, v[i]) - } - } - } - return req.Next() -} - -// WithHTTPHeader adds the specified http.Header to the parent context. -// Use this to specify custom HTTP headers at the API-call level. -// Any overlapping headers will have their values replaced with the values specified here. -func WithHTTPHeader(parent context.Context, header http.Header) context.Context { - return context.WithValue(parent, ctxWithHTTPHeader{}, header) -} diff --git a/sdk/azcore/poller.go b/sdk/azcore/poller.go deleted file mode 100644 index 340cb0859084..000000000000 --- a/sdk/azcore/poller.go +++ /dev/null @@ -1,448 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net/http" - "strconv" - "strings" - "time" - - "github.com/Azure/azure-sdk-for-go/sdk/internal/log" -) - -// NewPoller creates a Poller based on the provided initial response. -// pollerID - a unique identifier for an LRO, it's usually the client.Method string. -// NOTE: this is only meant for internal use in generated code. -func NewPoller(pollerID string, resp *http.Response, pl Pipeline, eu func(*http.Response) error) (*Poller, error) { - // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). - // ideally the codegen should return an error if the initial response failed and not even create a poller. - if !lroStatusCodeValid(resp) { - return nil, errors.New("the operation failed or was cancelled") - } - opLoc := resp.Header.Get(headerOperationLocation) - loc := resp.Header.Get(headerLocation) - // in the case of both headers, always prefer the operation-location header - if opLoc != "" { - return &Poller{ - lro: newOpPoller(pollerID, opLoc, loc, resp), - pl: pl, - eu: eu, - resp: resp, - }, nil - } - if loc != "" { - return &Poller{ - lro: newLocPoller(pollerID, loc, resp.StatusCode), - pl: pl, - eu: eu, - resp: resp, - }, nil - } - return &Poller{lro: &nopPoller{}, resp: resp}, nil -} - -// NewPollerFromResumeToken creates a Poller from a resume token string. -// pollerID - a unique identifier for an LRO, it's usually the client.Method string. -// NOTE: this is only meant for internal use in generated code. -func NewPollerFromResumeToken(pollerID string, token string, pl Pipeline, eu func(*http.Response) error) (*Poller, error) { - // unmarshal into JSON object to determine the poller type - obj := map[string]interface{}{} - err := json.Unmarshal([]byte(token), &obj) - if err != nil { - return nil, err - } - t, ok := obj["type"] - if !ok { - return nil, errors.New("missing type field") - } - tt, ok := t.(string) - if !ok { - return nil, fmt.Errorf("invalid type format %T", t) - } - // the type is encoded as "pollerType;lroPoller" - sem := strings.LastIndex(tt, ";") - if sem < 0 { - return nil, fmt.Errorf("invalid poller type %s", tt) - } - // ensure poller types match - if received := tt[:sem]; received != pollerID { - return nil, fmt.Errorf("cannot resume from this poller token. expected %s, received %s", pollerID, received) - } - // now rehydrate the poller based on the encoded poller type - var lro lroPoller - switch pt := tt[sem+1:]; pt { - case "opPoller": - lro = &opPoller{} - case "locPoller": - lro = &locPoller{} - default: - return nil, fmt.Errorf("unhandled lroPoller type %s", pt) - } - if err = json.Unmarshal([]byte(token), lro); err != nil { - return nil, err - } - return &Poller{lro: lro, pl: pl, eu: eu}, nil -} - -// Poller encapsulates state and logic for polling on long-running operations. -// NOTE: this is only meant for internal use in generated code. -type Poller struct { - lro lroPoller - pl Pipeline - eu func(*http.Response) error - resp *http.Response - err error -} - -// Done returns true if the LRO has reached a terminal state. -func (l *Poller) Done() bool { - if l.err != nil { - return true - } - return l.lro.Done() -} - -// Poll sends a polling request to the polling endpoint and returns the response or error. -func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { - if l.Done() { - // the LRO has reached a terminal state, don't poll again - if l.resp != nil { - return l.resp, nil - } - return nil, l.err - } - req, err := NewRequest(ctx, http.MethodGet, l.lro.URL()) - if err != nil { - return nil, err - } - resp, err := l.pl.Do(req) - if err != nil { - // don't update the poller for failed requests - return nil, err - } - if !lroStatusCodeValid(resp) { - // the LRO failed. unmarshall the error and update state - l.err = l.eu(resp) - l.resp = nil - return nil, l.err - } - if err = l.lro.Update(resp); err != nil { - return nil, err - } - l.resp = resp - return l.resp, nil -} - -// ResumeToken returns a token string that can be used to resume a poller that has not yet reached a terminal state. -func (l *Poller) ResumeToken() (string, error) { - if l.Done() { - return "", errors.New("cannot create a ResumeToken from a poller in a terminal state") - } - b, err := json.Marshal(l.lro) - if err != nil { - return "", err - } - return string(b), nil -} - -// FinalResponse will perform a final GET request and return the final HTTP response for the polling -// operation and unmarshall the content of the payload into the respType interface that is provided. -func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http.Response, error) { - if !l.Done() { - return nil, errors.New("cannot return a final response from a poller in a non-terminal state") - } - // if there's nothing to unmarshall into just return the final response - if respType == nil { - return l.resp, nil - } - u, err := l.lro.FinalGetURL(l.resp) - if err != nil { - return nil, err - } - if u != "" { - req, err := NewRequest(ctx, http.MethodGet, u) - if err != nil { - return nil, err - } - resp, err := l.pl.Do(req) - if err != nil { - return nil, err - } - if !lroStatusCodeValid(resp) { - return nil, l.eu(resp) - } - l.resp = resp - } - body, err := ioutil.ReadAll(l.resp.Body) - l.resp.Body.Close() - if err != nil { - return nil, err - } - if err = json.Unmarshal(body, respType); err != nil { - return nil, err - } - return l.resp, nil -} - -// PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, -// then return the final HTTP response for the polling operation and unmarshal the content of the payload -// into the respType interface that is provided. -func (l *Poller) PollUntilDone(ctx context.Context, freq time.Duration, respType interface{}) (*http.Response, error) { - logPollUntilDoneExit := func(v interface{}) { - log.Writef(log.LongRunningOperation, "END PollUntilDone() for %T: %v", l.lro, v) - } - log.Writef(log.LongRunningOperation, "BEGIN PollUntilDone() for %T", l.lro) - if l.resp != nil { - // initial check for a retry-after header existing on the initial response - if retryAfter := RetryAfter(l.resp); retryAfter > 0 { - log.Writef(log.LongRunningOperation, "initial Retry-After delay for %s", retryAfter.String()) - if err := delay(ctx, retryAfter); err != nil { - logPollUntilDoneExit(err) - return nil, err - } - } - } - // begin polling the endpoint until a terminal state is reached - for { - resp, err := l.Poll(ctx) - if err != nil { - logPollUntilDoneExit(err) - return nil, err - } - if l.Done() { - logPollUntilDoneExit(l.lro.Status()) - if !l.lro.Succeeded() { - return nil, l.eu(resp) - } - return l.FinalResponse(ctx, respType) - } - d := freq - if retryAfter := RetryAfter(resp); retryAfter > 0 { - log.Writef(log.LongRunningOperation, "Retry-After delay for %s", retryAfter.String()) - d = retryAfter - } else { - log.Writef(log.LongRunningOperation, "delay for %s", d.String()) - } - if err = delay(ctx, d); err != nil { - logPollUntilDoneExit(err) - return nil, err - } - } -} - -// abstracts the differences between concrete poller types -type lroPoller interface { - Done() bool - Update(resp *http.Response) error - FinalGetURL(resp *http.Response) (string, error) - URL() string - Status() string - Succeeded() bool -} - -// ==================================================================================================== - -// polls on the operation-location header -type opPoller struct { - Type string `json:"type"` - ReqMethod string `json:"reqMethod"` - ReqURL string `json:"reqURL"` - PollURL string `json:"pollURL"` - LocURL string `json:"locURL"` - status string -} - -func newOpPoller(pollerType, pollingURL, locationURL string, initialResponse *http.Response) *opPoller { - return &opPoller{ - Type: fmt.Sprintf("%s;opPoller", pollerType), - ReqMethod: initialResponse.Request.Method, - ReqURL: initialResponse.Request.URL.String(), - PollURL: pollingURL, - LocURL: locationURL, - } -} - -func (p *opPoller) URL() string { - return p.PollURL -} - -func (p *opPoller) Done() bool { - return strings.EqualFold(p.status, "succeeded") || - strings.EqualFold(p.status, "failed") || - strings.EqualFold(p.status, "cancelled") -} - -func (p *opPoller) Succeeded() bool { - return strings.EqualFold(p.status, "succeeded") -} - -func (p *opPoller) Update(resp *http.Response) error { - status, err := extractJSONValue(resp, "status") - if err != nil { - return err - } - if status == "" { - return errors.New("no status found in body") - } - p.status = status - // if the endpoint returned an operation-location header, update cached value - if opLoc := resp.Header.Get(headerOperationLocation); opLoc != "" { - p.PollURL = opLoc - } - return nil -} - -func (p *opPoller) FinalGetURL(resp *http.Response) (string, error) { - if !p.Done() { - return "", errors.New("cannot return a final response from a poller in a non-terminal state") - } - resLoc, err := extractJSONValue(resp, "resourceLocation") - if err != nil { - return "", err - } - if resLoc != "" { - return resLoc, nil - } - if p.ReqMethod == http.MethodPatch || p.ReqMethod == http.MethodPut { - return p.ReqURL, nil - } - if p.ReqMethod == http.MethodPost && p.LocURL != "" { - return p.LocURL, nil - } - return "", nil -} - -func (p *opPoller) Status() string { - return p.status -} - -// ==================================================================================================== - -// polls on the location header -type locPoller struct { - Type string `json:"type"` - PollURL string `json:"pollURL"` - status int -} - -func newLocPoller(pollerType, pollingURL string, initialStatus int) *locPoller { - return &locPoller{ - Type: fmt.Sprintf("%s;locPoller", pollerType), - PollURL: pollingURL, - status: initialStatus, - } -} - -func (p *locPoller) URL() string { - return p.PollURL -} - -func (p *locPoller) Done() bool { - // a 202 means the operation is still in progress - // zero-value indicates the poller was rehydrated from a token - return p.status > 0 && p.status != http.StatusAccepted -} - -func (p *locPoller) Succeeded() bool { - // any 2xx status code indicates success - return p.status >= 200 && p.status < 300 -} - -func (p *locPoller) Update(resp *http.Response) error { - // if the endpoint returned a location header, update cached value - if loc := resp.Header.Get(headerLocation); loc != "" { - p.PollURL = loc - } - p.status = resp.StatusCode - return nil -} - -func (*locPoller) FinalGetURL(*http.Response) (string, error) { - return "", nil -} - -func (p *locPoller) Status() string { - return strconv.Itoa(p.status) -} - -// ==================================================================================================== - -// used if the endpoint didn't return any polling headers (synchronous completion) -type nopPoller struct{} - -func (*nopPoller) URL() string { - return "" -} - -func (*nopPoller) Done() bool { - return true -} - -func (*nopPoller) Succeeded() bool { - return true -} - -func (*nopPoller) Update(*http.Response) error { - return nil -} - -func (*nopPoller) FinalGetURL(*http.Response) (string, error) { - return "", nil -} - -func (*nopPoller) Status() string { - return "succeeded" -} - -// returns true if the LRO response contains a valid HTTP status code -func lroStatusCodeValid(resp *http.Response) bool { - return HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) -} - -// extracs a JSON value from the provided reader -func extractJSONValue(resp *http.Response, val string) (string, error) { - if resp.ContentLength == 0 { - return "", errors.New("the response does not contain a body") - } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err - } - // put the body back so it's available to our callers - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - // unmarshall the body to get the value - var jsonBody map[string]interface{} - if err = json.Unmarshal(body, &jsonBody); err != nil { - return "", err - } - v, ok := jsonBody[val] - if !ok { - // it might be ok if the field doesn't exist, the caller must make that determination - return "", nil - } - vv, ok := v.(string) - if !ok { - return "", fmt.Errorf("the %s value %v was not in string format", val, v) - } - return vv, nil -} - -func delay(ctx context.Context, delay time.Duration) error { - select { - case <-time.After(delay): - return nil - case <-ctx.Done(): - return ctx.Err() - } -} diff --git a/sdk/azcore/request.go b/sdk/azcore/request.go deleted file mode 100644 index a3074596a2f8..000000000000 --- a/sdk/azcore/request.go +++ /dev/null @@ -1,398 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "encoding/xml" - "errors" - "fmt" - "io" - "io/ioutil" - "mime/multipart" - "net/http" - "reflect" - "strconv" - "strings" - - "golang.org/x/net/http/httpguts" -) - -const ( - contentTypeAppJSON = "application/json" - contentTypeAppXML = "application/xml" -) - -// Base64Encoding is usesd to specify which base-64 encoder/decoder to use when -// encoding/decoding a slice of bytes to/from a string. -type Base64Encoding int - -const ( - // Base64StdFormat uses base64.StdEncoding for encoding and decoding payloads. - Base64StdFormat Base64Encoding = 0 - - // Base64URLFormat uses base64.RawURLEncoding for encoding and decoding payloads. - Base64URLFormat Base64Encoding = 1 -) - -// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. -// Don't use this type directly, use NewRequest() instead. -type Request struct { - *http.Request - body ReadSeekCloser - policies []Policy - values opValues -} - -type opValues map[reflect.Type]interface{} - -// Set adds/changes a value -func (ov opValues) set(value interface{}) { - ov[reflect.TypeOf(value)] = value -} - -// Get looks for a value set by SetValue first -func (ov opValues) get(value interface{}) bool { - v, ok := ov[reflect.ValueOf(value).Elem().Type()] - if ok { - reflect.ValueOf(value).Elem().Set(reflect.ValueOf(v)) - } - return ok -} - -// JoinPaths concatenates multiple URL path segments into one path, -// inserting path separation characters as required. JoinPaths will preserve -// query parameters in the root path -func JoinPaths(root string, paths ...string) string { - if len(paths) == 0 { - return root - } - - qps := "" - if strings.Contains(root, "?") { - splitPath := strings.Split(root, "?") - root, qps = splitPath[0], splitPath[1] - } - - for i := 0; i < len(paths); i++ { - root = strings.TrimRight(root, "/") - paths[i] = strings.TrimLeft(paths[i], "/") - root += "/" + paths[i] - } - - if qps != "" { - if !strings.HasSuffix(root, "/") { - root += "/" - } - return root + "?" + qps - } - return root -} - -// NewRequest creates a new Request with the specified input. -func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) { - req, err := http.NewRequestWithContext(ctx, httpMethod, endpoint, nil) - if err != nil { - return nil, err - } - if req.URL.Host == "" { - return nil, errors.New("no Host in request URL") - } - if !(req.URL.Scheme == "http" || req.URL.Scheme == "https") { - return nil, fmt.Errorf("unsupported protocol scheme %s", req.URL.Scheme) - } - return &Request{Request: req}, nil -} - -// Next calls the next policy in the pipeline. -// If there are no more policies, nil and ErrNoMorePolicies are returned. -// This method is intended to be called from pipeline policies. -// To send a request through a pipeline call Pipeline.Do(). -func (req *Request) Next() (*http.Response, error) { - if len(req.policies) == 0 { - return nil, errors.New("no more policies") - } - nextPolicy := req.policies[0] - nextReq := *req - nextReq.policies = nextReq.policies[1:] - return nextPolicy.Do(&nextReq) -} - -// MarshalAsByteArray will base-64 encode the byte slice v, then calls SetBody. -// The encoded value is treated as a JSON string. -func (req *Request) MarshalAsByteArray(v []byte, format Base64Encoding) error { - // send as a JSON string - encode := fmt.Sprintf("\"%s\"", EncodeByteArray(v, format)) - return req.SetBody(NopCloser(strings.NewReader(encode)), contentTypeAppJSON) -} - -// MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody. -func (req *Request) MarshalAsJSON(v interface{}) error { - v = cloneWithoutReadOnlyFields(v) - b, err := json.Marshal(v) - if err != nil { - return fmt.Errorf("error marshalling type %T: %s", v, err) - } - return req.SetBody(NopCloser(bytes.NewReader(b)), contentTypeAppJSON) -} - -// MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody. -func (req *Request) MarshalAsXML(v interface{}) error { - b, err := xml.Marshal(v) - if err != nil { - return fmt.Errorf("error marshalling type %T: %s", v, err) - } - return req.SetBody(NopCloser(bytes.NewReader(b)), contentTypeAppXML) -} - -// SetOperationValue adds/changes a mutable key/value associated with a single operation. -func (req *Request) SetOperationValue(value interface{}) { - if req.values == nil { - req.values = opValues{} - } - req.values.set(value) -} - -// OperationValue looks for a value set by SetOperationValue(). -func (req *Request) OperationValue(value interface{}) bool { - if req.values == nil { - return false - } - return req.values.get(value) -} - -// SetBody sets the specified ReadSeekCloser as the HTTP request body. -func (req *Request) SetBody(body ReadSeekCloser, contentType string) error { - // Set the body and content length. - size, err := body.Seek(0, io.SeekEnd) // Seek to the end to get the stream's size - if err != nil { - return err - } - if size == 0 { - body.Close() - return nil - } - _, err = body.Seek(0, io.SeekStart) - if err != nil { - return err - } - // keep a copy of the original body. this is to handle cases - // where req.Body is replaced, e.g. httputil.DumpRequest and friends. - req.body = body - req.Request.Body = body - req.Request.ContentLength = size - req.Header.Set(headerContentType, contentType) - req.Header.Set(headerContentLength, strconv.FormatInt(size, 10)) - return nil -} - -// SetMultipartFormData writes the specified keys/values as multi-part form -// fields with the specified value. File content must be specified as a ReadSeekCloser. -// All other values are treated as string values. -func (req *Request) SetMultipartFormData(formData map[string]interface{}) error { - body := bytes.Buffer{} - writer := multipart.NewWriter(&body) - for k, v := range formData { - if rsc, ok := v.(ReadSeekCloser); ok { - // this is the body to upload, the key is its file name - fd, err := writer.CreateFormFile(k, k) - if err != nil { - return err - } - // copy the data to the form file - if _, err = io.Copy(fd, rsc); err != nil { - return err - } - continue - } - // ensure the value is in string format - s, ok := v.(string) - if !ok { - s = fmt.Sprintf("%v", v) - } - if err := writer.WriteField(k, s); err != nil { - return err - } - } - if err := writer.Close(); err != nil { - return err - } - req.body = NopCloser(bytes.NewReader(body.Bytes())) - req.Body = req.body - req.ContentLength = int64(body.Len()) - req.Header.Set(headerContentType, writer.FormDataContentType()) - req.Header.Set(headerContentLength, strconv.FormatInt(req.ContentLength, 10)) - return nil -} - -// SkipBodyDownload will disable automatic downloading of the response body. -func (req *Request) SkipBodyDownload() { - req.SetOperationValue(bodyDownloadPolicyOpValues{skip: true}) -} - -// RewindBody seeks the request's Body stream back to the beginning so it can be resent when retrying an operation. -func (req *Request) RewindBody() error { - if req.body != nil { - // Reset the stream back to the beginning and restore the body - _, err := req.body.Seek(0, io.SeekStart) - req.Body = req.body - return err - } - return nil -} - -// Close closes the request body. -func (req *Request) Close() error { - if req.Body == nil { - return nil - } - return req.Body.Close() -} - -// Telemetry adds telemetry data to the request. -// If telemetry reporting is disabled the value is discarded. -func (req *Request) Telemetry(v string) { - req.SetOperationValue(requestTelemetry(v)) -} - -// clone returns a deep copy of the request with its context changed to ctx -func (req *Request) clone(ctx context.Context) *Request { - r2 := Request{} - r2 = *req - r2.Request = req.Request.Clone(ctx) - return &r2 -} - -// valid returns nil if the underlying http.Request is well-formed. -func (req *Request) valid() error { - // check copied from Transport.roundTrip() - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - req.Close() - return fmt.Errorf("invalid header field name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - req.Close() - return fmt.Errorf("invalid header field value %q for key %v", v, k) - } - } - } - return nil -} - -// writes to a buffer, used for logging purposes -func (req *Request) writeBody(b *bytes.Buffer) error { - if req.Body == nil { - fmt.Fprint(b, " Request contained no body\n") - return nil - } - if ct := req.Header.Get(headerContentType); !shouldLogBody(b, ct) { - return nil - } - body, err := ioutil.ReadAll(req.Body) - if err != nil { - fmt.Fprintf(b, " Failed to read request body: %s\n", err.Error()) - return err - } - if err := req.RewindBody(); err != nil { - return err - } - logBody(b, body) - return nil -} - -// EncodeByteArray will base-64 encode the byte slice v. -func EncodeByteArray(v []byte, format Base64Encoding) string { - if format == Base64URLFormat { - return base64.RawURLEncoding.EncodeToString(v) - } - return base64.StdEncoding.EncodeToString(v) -} - -// returns a clone of the object graph pointed to by v, omitting values of all read-only -// fields. if there are no read-only fields in the object graph, no clone is created. -func cloneWithoutReadOnlyFields(v interface{}) interface{} { - val := reflect.Indirect(reflect.ValueOf(v)) - if val.Kind() != reflect.Struct { - // not a struct, skip - return v - } - // first walk the graph to find any R/O fields. - // if there aren't any, skip cloning the graph. - if !recursiveFindReadOnlyField(val) { - return v - } - return recursiveCloneWithoutReadOnlyFields(val) -} - -// returns true if any field in the object graph of val contains the `azure:"ro"` tag value -func recursiveFindReadOnlyField(val reflect.Value) bool { - t := val.Type() - // iterate over the fields, looking for the "azure" tag. - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - aztag := field.Tag.Get("azure") - if azureTagIsReadOnly(aztag) { - return true - } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct && recursiveFindReadOnlyField(reflect.Indirect(val.Field(i))) { - return true - } - } - return false -} - -// clones the object graph of val. all non-R/O properties are copied to the clone -func recursiveCloneWithoutReadOnlyFields(val reflect.Value) interface{} { - clone := reflect.New(val.Type()) - t := val.Type() - // iterate over the fields, looking for the "azure" tag. - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - aztag := field.Tag.Get("azure") - if azureTagIsReadOnly(aztag) { - // omit from payload - } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct { - // recursive case - v := recursiveCloneWithoutReadOnlyFields(reflect.Indirect(val.Field(i))) - if t.Field(i).Anonymous { - // NOTE: this does not handle the case of embedded fields of unexported struct types. - // this should be ok as we don't generate any code like this at present - reflect.Indirect(clone).Field(i).Set(reflect.Indirect(reflect.ValueOf(v))) - } else { - reflect.Indirect(clone).Field(i).Set(reflect.ValueOf(v)) - } - } else { - // no azure RO tag, non-recursive case, include in payload - reflect.Indirect(clone).Field(i).Set(val.Field(i)) - } - } - return clone.Interface() -} - -// returns true if the "azure" tag contains the option "ro" -func azureTagIsReadOnly(tag string) bool { - if tag == "" { - return false - } - parts := strings.Split(tag, ",") - for _, part := range parts { - if part == "ro" { - return true - } - } - return false -} - -func logBody(b *bytes.Buffer, body []byte) { - fmt.Fprintln(b, " --------------------------------------------------------------------------------") - fmt.Fprintln(b, string(body)) - fmt.Fprintln(b, " --------------------------------------------------------------------------------") -} diff --git a/sdk/azcore/runtime/errors.go b/sdk/azcore/runtime/errors.go new file mode 100644 index 000000000000..badf62a3bd37 --- /dev/null +++ b/sdk/azcore/runtime/errors.go @@ -0,0 +1,21 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// NewResponseError wraps the specified error with an error that provides access to an HTTP response. +// If an HTTP request returns a non-successful status code, wrap the response and the associated error +// in this error type so that callers can access the underlying *http.Response as required. +// DO NOT wrap failed HTTP requests that returned an error and no response with this type. +func NewResponseError(inner error, resp *http.Response) error { + return shared.NewResponseError(inner, resp) +} diff --git a/sdk/azcore/policy_body_download.go b/sdk/azcore/runtime/policy_body_download.go similarity index 80% rename from sdk/azcore/policy_body_download.go rename to sdk/azcore/runtime/policy_body_download.go index 8724b11f10b8..7fca2ba044d5 100644 --- a/sdk/azcore/policy_body_download.go +++ b/sdk/azcore/runtime/policy_body_download.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "errors" @@ -13,17 +13,21 @@ import ( "io/ioutil" "net/http" "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" ) // bodyDownloadPolicy creates a policy object that downloads the response's body to a []byte. -func bodyDownloadPolicy(req *Request) (*http.Response, error) { +func bodyDownloadPolicy(req *policy.Request) (*http.Response, error) { resp, err := req.Next() if err != nil { return resp, err } - var opValues bodyDownloadPolicyOpValues + var opValues shared.BodyDownloadPolicyOpValues // don't skip downloading error response bodies - if req.OperationValue(&opValues); opValues.skip && resp.StatusCode < 400 { + if req.OperationValue(&opValues); opValues.Skip && resp.StatusCode < 400 { return resp, err } // Either bodyDownloadPolicyOpValues was not specified (so skip is false) @@ -41,10 +45,10 @@ type bodyDownloadError struct { err error } -func newBodyDownloadError(err error, req *Request) error { +func newBodyDownloadError(err error, req *policy.Request) error { // on failure, only retry the request for idempotent operations. // we currently identify them as DELETE, GET, and PUT requests. - if m := strings.ToUpper(req.Method); m == http.MethodDelete || m == http.MethodGet || m == http.MethodPut { + if m := strings.ToUpper(req.Raw().Method); m == http.MethodDelete || m == http.MethodGet || m == http.MethodPut { // error is safe for retry return err } @@ -66,12 +70,7 @@ func (b *bodyDownloadError) Unwrap() error { return b.err } -var _ NonRetriableError = (*bodyDownloadError)(nil) - -// bodyDownloadPolicyOpValues is the struct containing the per-operation values -type bodyDownloadPolicyOpValues struct { - skip bool -} +var _ errorinfo.NonRetriable = (*bodyDownloadError)(nil) // nopClosingBytesReader is an io.ReadSeekCloser around a byte slice. // It also provides direct access to the byte slice. diff --git a/sdk/azcore/policy_body_download_test.go b/sdk/azcore/runtime/policy_body_download_test.go similarity index 99% rename from sdk/azcore/policy_body_download_test.go rename to sdk/azcore/runtime/policy_body_download_test.go index a7387cbedd4f..e4e70fa42779 100644 --- a/sdk/azcore/policy_body_download_test.go +++ b/sdk/azcore/runtime/policy_body_download_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" diff --git a/sdk/azcore/runtime/policy_http_header.go b/sdk/azcore/runtime/policy_http_header.go new file mode 100644 index 000000000000..148c6d9a313d --- /dev/null +++ b/sdk/azcore/runtime/policy_http_header.go @@ -0,0 +1,31 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// newHTTPHeaderPolicy creates a policy object that adds custom HTTP headers to a request +func httpHeaderPolicy(req *policy.Request) (*http.Response, error) { + // check if any custom HTTP headers have been specified + if header := req.Raw().Context().Value(shared.CtxWithHTTPHeaderKey{}); header != nil { + for k, v := range header.(http.Header) { + // use Set to replace any existing value + // it also canonicalizes the header key + req.Raw().Header.Set(k, v[0]) + // add any remaining values + for i := 1; i < len(v); i++ { + req.Raw().Header.Add(k, v[i]) + } + } + } + return req.Next() +} diff --git a/sdk/azcore/policy_http_header_test.go b/sdk/azcore/runtime/policy_http_header_test.go similarity index 91% rename from sdk/azcore/policy_http_header_test.go rename to sdk/azcore/runtime/policy_http_header_test.go index ededd5955d17..2e9aef013492 100644 --- a/sdk/azcore/policy_http_header_test.go +++ b/sdk/azcore/runtime/policy_http_header_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" @@ -12,6 +12,7 @@ import ( "net/textproto" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -29,13 +30,13 @@ func TestAddCustomHTTPHeaderSuccess(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest)) // HTTP header policy is automatically added during pipeline construction pl := NewPipeline(srv) - req, err := NewRequest(WithHTTPHeader(context.Background(), http.Header{ + req, err := NewRequest(policy.WithHTTPHeader(context.Background(), http.Header{ customHeader: []string{customValue}, }), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } - req.Header.Set(preexistingHeader, preexistingValue) + req.Raw().Header.Set(preexistingHeader, preexistingValue) resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -81,7 +82,7 @@ func TestAddCustomHTTPHeaderOverwrite(t *testing.T) { // HTTP header policy is automatically added during pipeline construction pl := NewPipeline(srv) // overwrite the request ID with our own value - req, err := NewRequest(WithHTTPHeader(context.Background(), http.Header{ + req, err := NewRequest(policy.WithHTTPHeader(context.Background(), http.Header{ customHeader: []string{customValue}, }), http.MethodGet, srv.URL()) if err != nil { @@ -112,7 +113,7 @@ func TestAddCustomHTTPHeaderMultipleValues(t *testing.T) { // HTTP header policy is automatically added during pipeline construction pl := NewPipeline(srv) // overwrite the request ID with our own value - req, err := NewRequest(WithHTTPHeader(context.Background(), http.Header{ + req, err := NewRequest(policy.WithHTTPHeader(context.Background(), http.Header{ customHeader: []string{customValue1, customValue2}, }), http.MethodGet, srv.URL()) if err != nil { diff --git a/sdk/azcore/policy_logging.go b/sdk/azcore/runtime/policy_logging.go similarity index 60% rename from sdk/azcore/policy_logging.go rename to sdk/azcore/runtime/policy_logging.go index b97189f3757b..b013a0c45f1a 100644 --- a/sdk/azcore/policy_logging.go +++ b/sdk/azcore/runtime/policy_logging.go @@ -4,36 +4,31 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" "fmt" + "io/ioutil" "net/http" "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/diag" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) -// LogOptions configures the logging policy's behavior. -type LogOptions struct { - // IncludeBody indicates if request and response bodies should be included in logging. - // The default value is false. - // NOTE: enabling this can lead to disclosure of sensitive information, use with care. - IncludeBody bool -} - type logPolicy struct { - options LogOptions + options policy.LogOptions } // NewLogPolicy creates a RequestLogPolicy object configured using the specified options. // Pass nil to accept the default values; this is the same as passing a zero-value options. -func NewLogPolicy(o *LogOptions) Policy { +func NewLogPolicy(o *policy.LogOptions) policy.Policy { if o == nil { - o = &LogOptions{} + o = &policy.LogOptions{} } return &logPolicy{options: *o} } @@ -44,7 +39,7 @@ type logPolicyOpValues struct { start time.Time } -func (p *logPolicy) Do(req *Request) (*http.Response, error) { +func (p *logPolicy) Do(req *policy.Request) (*http.Response, error) { // Get the per-operation values. These are saved in the Message's map so that they persist across each retry calling into this policy object. var opValues logPolicyOpValues if req.OperationValue(&opValues); opValues.start.IsZero() { @@ -60,7 +55,7 @@ func (p *logPolicy) Do(req *Request) (*http.Response, error) { writeRequestWithResponse(b, req, nil, nil) var err error if p.options.IncludeBody { - err = req.writeBody(b) + err = writeReqBody(req, b) } log.Write(log.Request, b.String()) if err != nil { @@ -88,9 +83,9 @@ func (p *logPolicy) Do(req *Request) (*http.Response, error) { writeRequestWithResponse(b, req, response, err) if err != nil { // skip frames runtime.Callers() and runtime.StackTrace() - b.WriteString(diag.StackTrace(2, StackFrameCount)) + b.WriteString(diag.StackTrace(2, 32)) } else if p.options.IncludeBody { - err = writeBody(response, b) + err = writeRespBody(response, b) } log.Write(log.Response, b.String()) } @@ -109,3 +104,52 @@ func shouldLogBody(b *bytes.Buffer, contentType string) bool { fmt.Fprintf(b, " Skip logging body for %s\n", contentType) return false } + +// writes to a buffer, used for logging purposes +func writeReqBody(req *policy.Request, b *bytes.Buffer) error { + if req.Raw().Body == nil { + fmt.Fprint(b, " Request contained no body\n") + return nil + } + if ct := req.Raw().Header.Get(shared.HeaderContentType); !shouldLogBody(b, ct) { + return nil + } + body, err := ioutil.ReadAll(req.Raw().Body) + if err != nil { + fmt.Fprintf(b, " Failed to read request body: %s\n", err.Error()) + return err + } + if err := req.RewindBody(); err != nil { + return err + } + logBody(b, body) + return nil +} + +// writes to a buffer, used for logging purposes +func writeRespBody(resp *http.Response, b *bytes.Buffer) error { + ct := resp.Header.Get(shared.HeaderContentType) + if ct == "" { + fmt.Fprint(b, " Response contained no body\n") + return nil + } else if !shouldLogBody(b, ct) { + return nil + } + body, err := Payload(resp) + if err != nil { + fmt.Fprintf(b, " Failed to read response body: %s\n", err.Error()) + return err + } + if len(body) > 0 { + logBody(b, body) + } else { + fmt.Fprint(b, " Response contained no body\n") + } + return nil +} + +func logBody(b *bytes.Buffer, body []byte) { + fmt.Fprintln(b, " --------------------------------------------------------------------------------") + fmt.Fprintln(b, string(body)) + fmt.Fprintln(b, " --------------------------------------------------------------------------------") +} diff --git a/sdk/azcore/policy_logging_test.go b/sdk/azcore/runtime/policy_logging_test.go similarity index 95% rename from sdk/azcore/policy_logging_test.go rename to sdk/azcore/runtime/policy_logging_test.go index 6524ae978ce9..6ef960ce775f 100644 --- a/sdk/azcore/policy_logging_test.go +++ b/sdk/azcore/runtime/policy_logging_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -31,10 +31,10 @@ func TestPolicyLoggingSuccess(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - qp := req.URL.Query() + qp := req.Raw().URL.Query() qp.Set("one", "fish") qp.Set("sig", "redact") - req.URL.RawQuery = qp.Encode() + req.Raw().URL.RawQuery = qp.Encode() resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -81,8 +81,8 @@ func TestPolicyLoggingError(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - req.Header.Add("header", "one") - req.Header.Add("Authorization", "redact") + req.Raw().Header.Add("header", "one") + req.Raw().Header.Add("Authorization", "redact") resp, err := pl.Do(req) if err == nil { t.Fatal("unexpected nil error") diff --git a/sdk/azcore/policy_retry.go b/sdk/azcore/runtime/policy_retry.go similarity index 60% rename from sdk/azcore/policy_retry.go rename to sdk/azcore/runtime/policy_retry.go index b916f28fbadd..55eedd1a5652 100644 --- a/sdk/azcore/policy_retry.go +++ b/sdk/azcore/runtime/policy_retry.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" @@ -15,46 +15,15 @@ import ( "net/http" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) -const ( - defaultMaxRetries = 3 -) - -// RetryOptions configures the retry policy's behavior. -// All zero-value fields will be initialized with their default values. -type RetryOptions struct { - // MaxRetries specifies the maximum number of attempts a failed operation will be retried - // before producing an error. - // The default value is three. A value less than zero means one try and no retries. - MaxRetries int32 - - // TryTimeout indicates the maximum time allowed for any single try of an HTTP request. - // This is disabled by default. Specify a value greater than zero to enable. - // NOTE: Setting this to a small value might cause premature HTTP request time-outs. - TryTimeout time.Duration - - // RetryDelay specifies the initial amount of delay to use before retrying an operation. - // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay. - // The default value is four seconds. A value less than zero means no delay between retries. - RetryDelay time.Duration - - // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. - // Typically the value is greater than or equal to the value specified in RetryDelay. - // The default Value is 120 seconds. A value less than zero means there is no cap. - MaxRetryDelay time.Duration - - // StatusCodes specifies the HTTP status codes that indicate the operation should be retried. - // The default value is the status codes in StatusCodesForRetry. - // Specifying an empty slice will cause retries to happen only for transport errors. - StatusCodes []int -} - -// init sets any default values -func (o *RetryOptions) init() { +func setDefaults(o *policy.RetryOptions) { if o.MaxRetries == 0 { - o.MaxRetries = defaultMaxRetries + o.MaxRetries = shared.DefaultMaxRetries } else if o.MaxRetries < 0 { o.MaxRetries = 0 } @@ -80,17 +49,7 @@ func (o *RetryOptions) init() { } } -// used as a context key for adding/retrieving RetryOptions -type ctxWithRetryOptionsKey struct{} - -// WithRetryOptions adds the specified RetryOptions to the parent context. -// Use this to specify custom RetryOptions at the API-call level. -func WithRetryOptions(parent context.Context, options RetryOptions) context.Context { - options.init() - return context.WithValue(parent, ctxWithRetryOptionsKey{}, options) -} - -func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never 0 +func calcDelay(o policy.RetryOptions, try int32) time.Duration { // try is >=1; never 0 pow := func(number int64, exponent int32) int64 { // pow is nested helper function var result int64 = 1 for n := int32(0); n < exponent; n++ { @@ -111,42 +70,38 @@ func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never // NewRetryPolicy creates a policy object configured using the specified options. // Pass nil to accept the default values; this is the same as passing a zero-value options. -func NewRetryPolicy(o *RetryOptions) Policy { +func NewRetryPolicy(o *policy.RetryOptions) policy.Policy { if o == nil { - o = &RetryOptions{} + o = &policy.RetryOptions{} } p := &retryPolicy{options: *o} - // fix up values in the copy - p.options.init() return p } type retryPolicy struct { - options RetryOptions + options policy.RetryOptions } -func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { +func (p *retryPolicy) Do(req *policy.Request) (resp *http.Response, err error) { options := p.options // check if the retry options have been overridden for this call - if override := req.Context().Value(ctxWithRetryOptionsKey{}); override != nil { - options = override.(RetryOptions) + if override := req.Raw().Context().Value(shared.CtxWithRetryOptionsKey{}); override != nil { + options = override.(policy.RetryOptions) } + setDefaults(&options) // Exponential retry algorithm: ((2 ^ attempt) - 1) * delay * random(0.8, 1.2) // When to retry: connection failure or temporary/timeout. - if req.body != nil { - // wrap the body so we control when it's actually closed - rwbody := &retryableRequestBody{body: req.body} - req.body = rwbody - req.Request.GetBody = func() (io.ReadCloser, error) { - _, err := rwbody.Seek(0, io.SeekStart) // Seek back to the beginning of the stream - return rwbody, err - } + var rwbody *retryableRequestBody + if req.Body() != nil { + // wrap the body so we control when it's actually closed. + // do this outside the for loop so defers don't accumulate. + rwbody = &retryableRequestBody{body: req.Body()} defer rwbody.realClose() } try := int32(1) for { resp = nil // reset - log.Writef(log.RetryPolicy, "\n=====> Try=%d %s %s", try, req.Method, req.URL.String()) + log.Writef(log.RetryPolicy, "\n=====> Try=%d %s %s", try, req.Raw().Method, req.Raw().URL.String()) // For each try, seek to the beginning of the Body stream. We do this even for the 1st try because // the stream may not be at offset 0 when we first get it and we want the same behavior for the @@ -155,13 +110,17 @@ func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { if err != nil { return } + // RewindBody() restores Raw().Body to its original state, so set our rewindable after + if rwbody != nil { + req.Raw().Body = rwbody + } if options.TryTimeout == 0 { resp, err = req.Next() } else { // Set the per-try time for this particular retry operation and then Do the operation. - tryCtx, tryCancel := context.WithTimeout(req.Context(), options.TryTimeout) - clone := req.clone(tryCtx) + tryCtx, tryCancel := context.WithTimeout(req.Raw().Context(), options.TryTimeout) + clone := req.Clone(tryCtx) resp, err = clone.Next() // Make the request tryCancel() } @@ -174,7 +133,7 @@ func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { if err == nil && !HasStatusCode(resp, options.StatusCodes...) { // if there is no error and the response code isn't in the list of retry codes then we're done. return - } else if ctxErr := req.Context().Err(); ctxErr != nil { + } else if ctxErr := req.Raw().Context().Err(); ctxErr != nil { // don't retry if the parent context has been cancelled or its deadline exceeded err = ctxErr log.Writef(log.RetryPolicy, "abort due to %v", err) @@ -182,7 +141,7 @@ func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { } // check if the error is not retriable - var nre NonRetriableError + var nre errorinfo.NonRetriable if errors.As(err, &nre) { // the error says it's not retriable so don't retry log.Writef(log.RetryPolicy, "non-retriable error %T", nre) @@ -199,16 +158,16 @@ func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { Drain(resp) // use the delay from retry-after if available - delay := RetryAfter(resp) + delay := shared.RetryAfter(resp) if delay <= 0 { - delay = options.calcDelay(try) + delay = calcDelay(options, try) } log.Writef(log.RetryPolicy, "End Try #%d, Delay=%v", try, delay) select { case <-time.After(delay): try++ - case <-req.Context().Done(): - err = req.Context().Err() + case <-req.Raw().Context().Done(): + err = req.Raw().Context().Err() log.Writef(log.RetryPolicy, "abort due to %v", err) return } diff --git a/sdk/azcore/policy_retry_test.go b/sdk/azcore/runtime/policy_retry_test.go similarity index 90% rename from sdk/azcore/policy_retry_test.go rename to sdk/azcore/runtime/policy_retry_test.go index f7ab18e4c556..89d46c3ac260 100644 --- a/sdk/azcore/policy_retry_test.go +++ b/sdk/azcore/runtime/policy_retry_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -17,13 +17,17 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) -func testRetryOptions() *RetryOptions { - def := RetryOptions{} - def.RetryDelay = 20 * time.Millisecond - return &def +func testRetryOptions() *policy.RetryOptions { + return &policy.RetryOptions{ + RetryDelay: 20 * time.Millisecond, + } } func TestRetryPolicySuccess(t *testing.T) { @@ -74,10 +78,10 @@ func TestRetryPolicyFailOnStatusCode(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } - if body.rcount != defaultMaxRetries { + if body.rcount != shared.DefaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -92,12 +96,12 @@ func TestRetryPolicyFailOnStatusCodeRespBodyPreserved(t *testing.T) { srv.SetResponse(mock.WithStatusCode(http.StatusInternalServerError), mock.WithBody([]byte(respBody))) // add a per-request policy that reads and restores the request body. // this is to simulate how something like httputil.DumpRequest works. - pl := NewPipeline(srv, policyFunc(func(r *Request) (*http.Response, error) { - b, err := ioutil.ReadAll(r.Body) + pl := NewPipeline(srv, pipeline.PolicyFunc(func(r *policy.Request) (*http.Response, error) { + b, err := ioutil.ReadAll(r.Raw().Body) if err != nil { t.Fatal(err) } - r.Body = ioutil.NopCloser(bytes.NewReader(b)) + r.Raw().Body = ioutil.NopCloser(bytes.NewReader(b)) return r.Next() }), NewRetryPolicy(testRetryOptions())) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) @@ -115,10 +119,10 @@ func TestRetryPolicyFailOnStatusCodeRespBodyPreserved(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } - if body.rcount != defaultMaxRetries { + if body.rcount != shared.DefaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -210,7 +214,7 @@ func TestRetryPolicyNoRetries(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) srv.AppendResponse() - pl := NewPipeline(srv, NewRetryPolicy(&RetryOptions{MaxRetries: -1})) + pl := NewPipeline(srv, NewRetryPolicy(&policy.RetryOptions{MaxRetries: -1})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -273,10 +277,10 @@ func TestRetryPolicyFailOnError(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } - if body.rcount != defaultMaxRetries { + if body.rcount != shared.DefaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -307,10 +311,10 @@ func TestRetryPolicySuccessWithRetryComplex(t *testing.T) { if resp.StatusCode != http.StatusAccepted { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } - if body.rcount != defaultMaxRetries { + if body.rcount != shared.DefaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -360,7 +364,7 @@ func (f fatalError) NonRetriable() { // marker method } -var _ NonRetriableError = (*fatalError)(nil) +var _ errorinfo.NonRetriable = (*fatalError)(nil) func TestRetryPolicyIsNotRetriable(t *testing.T) { theErr := fatalError{s: "it's dead Jim"} @@ -395,7 +399,7 @@ func TestWithRetryOptions(t *testing.T) { customOptions := *defaultOptions customOptions.MaxRetries = 10 customOptions.MaxRetryDelay = 200 * time.Millisecond - retryCtx := WithRetryOptions(context.Background(), customOptions) + retryCtx := policy.WithRetryOptions(context.Background(), customOptions) req, err := NewRequest(retryCtx, http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -437,8 +441,8 @@ func TestRetryPolicyFailOnErrorNoDownload(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } } @@ -613,7 +617,7 @@ func (r *rewindTrackingBody) Seek(offset int64, whence int) (int64, error) { // used to inject a nil response type nilRespInjector struct { - t Transporter + t policy.Transporter c int // the current request number r []int // the list of request numbers to return a nil response (one-based) } diff --git a/sdk/azcore/policy_telemetry.go b/sdk/azcore/runtime/policy_telemetry.go similarity index 54% rename from sdk/azcore/policy_telemetry.go rename to sdk/azcore/runtime/policy_telemetry.go index a3b0bc09eace..5e628e7a3257 100644 --- a/sdk/azcore/policy_telemetry.go +++ b/sdk/azcore/runtime/policy_telemetry.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -13,32 +13,21 @@ import ( "os" "runtime" "strings" -) - -// TelemetryOptions configures the telemetry policy's behavior. -type TelemetryOptions struct { - // Value is a string prepended to each request's User-Agent and sent to the service. - // The service records the user-agent in logs for diagnostics and tracking of client requests. - Value string - - // ApplicationID is an application-specific identification string used in telemetry. - // It has a maximum length of 24 characters and must not contain any spaces. - ApplicationID string - // Disabled will prevent the addition of any telemetry data to the User-Agent. - Disabled bool -} + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) type telemetryPolicy struct { telemetryValue string } // NewTelemetryPolicy creates a telemetry policy object that adds telemetry information to outgoing HTTP requests. -// The format is [ ]azsdk--/ []. +// The format is [ ]azsdk-go-/ . // Pass nil to accept the default values; this is the same as passing a zero-value options. -func NewTelemetryPolicy(o *TelemetryOptions) Policy { +func NewTelemetryPolicy(mod, ver string, o *policy.TelemetryOptions) policy.Policy { if o == nil { - o = &TelemetryOptions{} + o = &policy.TelemetryOptions{} } tp := telemetryPolicy{} if o.Disabled { @@ -54,31 +43,29 @@ func NewTelemetryPolicy(o *TelemetryOptions) Policy { b.WriteString(o.ApplicationID) b.WriteRune(' ') } - // write out telemetry string - if o.Value != "" { - b.WriteString(o.Value) - b.WriteRune(' ') - } - b.WriteString(UserAgent) + b.WriteString(formatTelemetry(mod, ver)) + b.WriteRune(' ') + // inject azcore info + b.WriteString(formatTelemetry(shared.Module, shared.Version)) b.WriteRune(' ') b.WriteString(platformInfo) tp.telemetryValue = b.String() return &tp } -func (p telemetryPolicy) Do(req *Request) (*http.Response, error) { +func formatTelemetry(comp, ver string) string { + return fmt.Sprintf("azsdk-go-%s/%s", comp, ver) +} + +func (p telemetryPolicy) Do(req *policy.Request) (*http.Response, error) { if p.telemetryValue == "" { return req.Next() } // preserve the existing User-Agent string - if ua := req.Request.Header.Get(headerUserAgent); ua != "" { + if ua := req.Raw().Header.Get(shared.HeaderUserAgent); ua != "" { p.telemetryValue = fmt.Sprintf("%s %s", p.telemetryValue, ua) } - var rt requestTelemetry - if req.OperationValue(&rt) { - p.telemetryValue = fmt.Sprintf("%s %s", string(rt), p.telemetryValue) - } - req.Request.Header.Set(headerUserAgent, p.telemetryValue) + req.Raw().Header.Set(shared.HeaderUserAgent, p.telemetryValue) return req.Next() } @@ -93,6 +80,3 @@ var platformInfo = func() string { } return fmt.Sprintf("(%s; %s)", runtime.Version(), operatingSystem) }() - -// used for adding per-request telemetry -type requestTelemetry string diff --git a/sdk/azcore/policy_telemetry_test.go b/sdk/azcore/runtime/policy_telemetry_test.go similarity index 52% rename from sdk/azcore/policy_telemetry_test.go rename to sdk/azcore/runtime/policy_telemetry_test.go index 6149eadae55f..03ea6949615e 100644 --- a/sdk/azcore/policy_telemetry_test.go +++ b/sdk/azcore/runtime/policy_telemetry_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" @@ -12,16 +12,18 @@ import ( "net/http" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) -var defaultTelemetry = UserAgent + " " + platformInfo +var defaultTelemetry = "azsdk-go-" + shared.Module + "/" + shared.Version + " " + platformInfo func TestPolicyTelemetryDefault(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse() - pl := NewPipeline(srv, NewTelemetryPolicy(nil)) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -30,26 +32,7 @@ func TestPolicyTelemetryDefault(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != defaultTelemetry { - t.Fatalf("unexpected user agent value: %s", v) - } -} - -func TestPolicyTelemetryWithCustomInfo(t *testing.T) { - srv, close := mock.NewServer() - defer close() - srv.SetResponse() - const testValue = "azcore_test" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{Value: testValue})) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - resp, err := pl.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s", testValue, defaultTelemetry) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != "azsdk-go-test/v1.2.3 "+defaultTelemetry { t.Fatalf("unexpected user agent value: %s", v) } } @@ -58,18 +41,18 @@ func TestPolicyTelemetryPreserveExisting(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse() - pl := NewPipeline(srv, NewTelemetryPolicy(nil)) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } const otherValue = "this should stay" - req.Header.Set(headerUserAgent, otherValue) + req.Raw().Header.Set(shared.HeaderUserAgent, otherValue) resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s", defaultTelemetry, otherValue) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != fmt.Sprintf("%s %s", "azsdk-go-test/v1.2.3 "+defaultTelemetry, otherValue) { t.Fatalf("unexpected user agent value: %s", v) } } @@ -79,36 +62,16 @@ func TestPolicyTelemetryWithAppID(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID})) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - resp, err := pl.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s", appID, defaultTelemetry) { - t.Fatalf("unexpected user agent value: %s", v) - } -} - -func TestPolicyTelemetryWithAppIDAndReqTelemetry(t *testing.T) { - srv, close := mock.NewServer() - defer close() - srv.SetResponse() - const appID = "my_application" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID})) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } - req.Telemetry("TestPolicyTelemetryWithAppIDAndReqTelemetry") resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s %s", "TestPolicyTelemetryWithAppIDAndReqTelemetry", appID, defaultTelemetry) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != fmt.Sprintf("%s %s", appID, "azsdk-go-test/v1.2.3 "+defaultTelemetry) { t.Fatalf("unexpected user agent value: %s", v) } } @@ -118,7 +81,7 @@ func TestPolicyTelemetryWithAppIDSanitized(t *testing.T) { defer close() srv.SetResponse() const appID = "This will get the spaces removed and truncated." - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID})) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -128,7 +91,7 @@ func TestPolicyTelemetryWithAppIDSanitized(t *testing.T) { t.Fatalf("unexpected error: %v", err) } const newAppID = "This/will/get/the/spaces" - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s", newAppID, defaultTelemetry) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != fmt.Sprintf("%s %s", newAppID, "azsdk-go-test/v1.2.3 "+defaultTelemetry) { t.Fatalf("unexpected user agent value: %s", v) } } @@ -138,18 +101,18 @@ func TestPolicyTelemetryPreserveExistingWithAppID(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID})) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } const otherValue = "this should stay" - req.Header.Set(headerUserAgent, otherValue) + req.Raw().Header.Set(shared.HeaderUserAgent, otherValue) resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s %s", appID, defaultTelemetry, otherValue) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != fmt.Sprintf("%s %s %s", appID, "azsdk-go-test/v1.2.3 "+defaultTelemetry, otherValue) { t.Fatalf("unexpected user agent value: %s", v) } } @@ -159,17 +122,16 @@ func TestPolicyTelemetryDisabled(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID, Disabled: true})) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID, Disabled: true})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } - req.Telemetry("this should be ignored") resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != "" { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != "" { t.Fatalf("unexpected user agent value: %s", v) } } diff --git a/sdk/azcore/runtime/poller.go b/sdk/azcore/runtime/poller.go new file mode 100644 index 000000000000..686c04725da6 --- /dev/null +++ b/sdk/azcore/runtime/poller.go @@ -0,0 +1,70 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// NewPoller creates a Poller based on the provided initial response. +// pollerID - a unique identifier for an LRO, it's usually the client.Method string. +func NewPoller(pollerID string, resp *http.Response, pl pipeline.Pipeline, eu func(*http.Response) error) (*pollers.Poller, error) { + // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). + // ideally the codegen should return an error if the initial response failed and not even create a poller. + if !pollers.StatusCodeValid(resp) { + return nil, errors.New("the operation failed or was cancelled") + } + // determine the polling method + var lro pollers.Operation + var err error + // op poller must be checked first as it can also have a location header + if op.Applicable(resp) { + lro, err = op.New(resp, pollerID) + } else if loc.Applicable(resp) { + lro, err = loc.New(resp, pollerID) + } else { + lro = &pollers.NopPoller{} + } + if err != nil { + return nil, err + } + return pollers.NewPoller(lro, resp, pl, eu), nil +} + +// NewPollerFromResumeToken creates a Poller from a resume token string. +// pollerID - a unique identifier for an LRO, it's usually the client.Method string. +func NewPollerFromResumeToken(pollerID string, token string, pl pipeline.Pipeline, eu func(*http.Response) error) (*pollers.Poller, error) { + kind, err := pollers.KindFromToken(pollerID, token) + if err != nil { + return nil, err + } + // now rehydrate the poller based on the encoded poller type + var lro pollers.Operation + switch kind { + case loc.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", loc.Kind) + lro = &loc.Poller{} + case op.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", op.Kind) + lro = &op.Poller{} + default: + return nil, fmt.Errorf("unhandled poller type %s", kind) + } + if err = json.Unmarshal([]byte(token), lro); err != nil { + return nil, err + } + return pollers.NewPoller(lro, nil, pl, eu), nil +} diff --git a/sdk/azcore/poller_test.go b/sdk/azcore/runtime/poller_test.go similarity index 97% rename from sdk/azcore/poller_test.go rename to sdk/azcore/runtime/poller_test.go index 36da6e548048..b7d69a1afe53 100644 --- a/sdk/azcore/poller_test.go +++ b/sdk/azcore/runtime/poller_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" @@ -12,9 +12,11 @@ import ( "fmt" "net/http" "net/url" + "reflect" "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -75,26 +77,18 @@ func TestNewPollerFromResumeTokenFail(t *testing.T) { } } -func TestOpPollerSimple(t *testing.T) { +func TestLocPollerSimple(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{ "status": "Succeeded"}`))) defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPut, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -111,28 +105,18 @@ func TestOpPollerSimple(t *testing.T) { } } -func TestOpPollerWithWidgetPUT(t *testing.T) { +func TestLocPollerWithWidget(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`)), mock.WithHeader("Retry-After", "1")) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded"}`))) - // PUT and PATCH state that a final GET will happen - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 3}`))) - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPut, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -148,34 +132,23 @@ func TestOpPollerWithWidgetPUT(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 2 { + if w.Size != 3 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestOpPollerWithWidgetPOSTLocation(t *testing.T) { +func TestLocPollerCancelled(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded"}`))) - // POST state that a final GET will happen from the URL provided in the Location header if available - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(`{"error": "cancelled"}`))) - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPost, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -185,38 +158,32 @@ func TestOpPollerWithWidgetPOSTLocation(t *testing.T) { } var w widget resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) - if err != nil { - t.Fatal(err) + if err == nil { + t.Fatal("unexpected nil error") } - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status code %d", resp.StatusCode) + if _, ok := err.(pollerError); !ok { + t.Fatal("expected pollerError") } - if w.Size != 2 { + if resp != nil { + t.Fatal("expected nil response") + } + if w.Size != 0 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestOpPollerWithWidgetPOST(t *testing.T) { +func TestLocPollerWithError(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - // POST with no location header means the success response returns the model - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded", "size": 2}`))) defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendError(errors.New("oops")) - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPost, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -226,41 +193,32 @@ func TestOpPollerWithWidgetPOST(t *testing.T) { } var w widget resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) - if err != nil { - t.Fatal(err) + if err == nil { + t.Fatal("unexpected nil error") } - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status code %d", resp.StatusCode) + if e := err.Error(); e != "oops" { + t.Fatalf("expected error %s", e) } - if w.Size != 2 { + if resp != nil { + t.Fatal("expected nil response") + } + if w.Size != 0 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestOpPollerWithWidgetResourceLocation(t *testing.T) { +func TestLocPollerWithResumeToken(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte( - fmt.Sprintf(`{"status": "Succeeded", "resourceLocation": "%s"}`, srv.URL())))) - // final GET will happen from the URL provided in the resourceLocation - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) defer close() - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPatch, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -268,20 +226,70 @@ func TestOpPollerWithWidgetResourceLocation(t *testing.T) { if err != nil { t.Fatal(err) } - var w widget - resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) + resp, err := lro.Poll(context.Background()) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } + if lro.Done() { + t.Fatal("poller shouldn't be done yet") + } + resp, err = lro.FinalResponse(context.Background(), nil) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + tk, err := lro.ResumeToken() + if err != nil { + t.Fatal(err) + } + lro, err = NewPollerFromResumeToken("fake.poller", tk, pl, errUnmarshall) + if err != nil { + t.Fatal(err) + } + resp, err = lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 2 { - t.Fatalf("unexpected widget size %d", w.Size) +} + +func TestLocPollerWithTimeout(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithSlowResponse(2 * time.Second)) + defer close() + + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Location": []string{srv.URL()}, + }, + } + pl := NewPipeline(srv) + lro, err := NewPoller("fake.poller", firstResp, pl, errUnmarshall) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + resp, err := lro.PollUntilDone(ctx, 5*time.Millisecond, nil) + cancel() + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") } } -func TestOpPollerWithResumeToken(t *testing.T) { +func TestOpPollerSimple(t *testing.T) { srv, close := mock.NewServer() srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) @@ -293,13 +301,14 @@ func TestOpPollerWithResumeToken(t *testing.T) { t.Fatal(err) } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ "Operation-Location": []string{srv.URL()}, "Retry-After": []string{"1"}, }, Request: &http.Request{ - Method: http.MethodPut, + Method: http.MethodDelete, URL: reqURL, }, } @@ -308,32 +317,7 @@ func TestOpPollerWithResumeToken(t *testing.T) { if err != nil { t.Fatal(err) } - resp, err := lro.Poll(context.Background()) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != http.StatusAccepted { - t.Fatalf("unexpected status code %d", resp.StatusCode) - } - if lro.Done() { - t.Fatal("poller shouldn't be done yet") - } - resp, err = lro.FinalResponse(context.Background(), nil) - if err == nil { - t.Fatal("unexpected nil error") - } - if resp != nil { - t.Fatal("expected nil response") - } - tk, err := lro.ResumeToken() - if err != nil { - t.Fatal(err) - } - lro, err = NewPollerFromResumeToken("fake.poller", tk, pl, errUnmarshall) - if err != nil { - t.Fatal(err) - } - resp, err = lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) + resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) if err != nil { t.Fatal(err) } @@ -342,18 +326,29 @@ func TestOpPollerWithResumeToken(t *testing.T) { } } -func TestLocPollerSimple(t *testing.T) { +func TestOpPollerWithWidgetPUT(t *testing.T) { srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`)), mock.WithHeader("Retry-After", "1")) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded"}`))) + // PUT and PATCH state that a final GET will happen + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPut, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -361,27 +356,43 @@ func TestLocPollerSimple(t *testing.T) { if err != nil { t.Fatal(err) } - resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) + var w widget + resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code %d", resp.StatusCode) } + if w.Size != 2 { + t.Fatalf("unexpected widget size %d", w.Size) + } } -func TestLocPollerWithWidget(t *testing.T) { +func TestOpPollerWithWidgetPOSTLocation(t *testing.T) { srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded"}`))) + // POST state that a final GET will happen from the URL provided in the Location header if available + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 3}`))) + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPost, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -397,23 +408,33 @@ func TestLocPollerWithWidget(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 3 { + if w.Size != 2 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestLocPollerCancelled(t *testing.T) { +func TestOpPollerWithWidgetPOST(t *testing.T) { srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + // POST with no location header means the success response returns the model + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded", "size": 2}`))) defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(`{"error": "cancelled"}`))) + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPost, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -423,32 +444,42 @@ func TestLocPollerCancelled(t *testing.T) { } var w widget resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) - if err == nil { - t.Fatal("unexpected nil error") - } - if _, ok := err.(pollerError); !ok { - t.Fatal("expected pollerError") + if err != nil { + t.Fatal(err) } - if resp != nil { - t.Fatal("expected nil response") + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 0 { + if w.Size != 2 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestLocPollerWithError(t *testing.T) { +func TestOpPollerWithWidgetResourceLocation(t *testing.T) { srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte( + fmt.Sprintf(`{"status": "Succeeded", "resourceLocation": "%s"}`, srv.URL())))) + // final GET will happen from the URL provided in the resourceLocation + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendError(errors.New("oops")) + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPatch, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -458,32 +489,38 @@ func TestLocPollerWithError(t *testing.T) { } var w widget resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) - if err == nil { - t.Fatal("unexpected nil error") - } - if e := err.Error(); e != "oops" { - t.Fatalf("expected error %s", e) + if err != nil { + t.Fatal(err) } - if resp != nil { - t.Fatal("expected nil response") + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 0 { + if w.Size != 2 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestLocPollerWithResumeToken(t *testing.T) { +func TestOpPollerWithResumeToken(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{ "status": "Succeeded"}`))) defer close() + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodDelete, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -525,35 +562,6 @@ func TestLocPollerWithResumeToken(t *testing.T) { } } -func TestLocPollerWithTimeout(t *testing.T) { - srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithSlowResponse(2 * time.Second)) - defer close() - - firstResp := &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Location": []string{srv.URL()}, - }, - } - pl := NewPipeline(srv) - lro, err := NewPoller("fake.poller", firstResp, pl, errUnmarshall) - if err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - resp, err := lro.PollUntilDone(ctx, 5*time.Millisecond, nil) - cancel() - if err == nil { - t.Fatal("unexpected nil error") - } - if resp != nil { - t.Fatal("expected nil response") - } -} - func TestNopPoller(t *testing.T) { firstResp := &http.Response{ StatusCode: http.StatusOK, @@ -563,8 +571,8 @@ func TestNopPoller(t *testing.T) { if err != nil { t.Fatal(err) } - if _, ok := lro.lro.(*nopPoller); !ok { - t.Fatalf("unexpected poller type %T", lro.lro) + if pt := pollers.PollerType(lro); pt != reflect.TypeOf(&pollers.NopPoller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) } if !lro.Done() { t.Fatal("expected Done() for nopPoller") diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go new file mode 100644 index 000000000000..d72b68791c4e --- /dev/null +++ b/sdk/azcore/runtime/request.go @@ -0,0 +1,228 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "mime/multipart" + "reflect" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// Pipeline represents a primitive for sending HTTP requests and receiving responses. +// Its behavior can be extended by specifying policies during construction. +type Pipeline = pipeline.Pipeline + +// Base64Encoding is usesd to specify which base-64 encoder/decoder to use when +// encoding/decoding a slice of bytes to/from a string. +type Base64Encoding int + +const ( + // Base64StdFormat uses base64.StdEncoding for encoding and decoding payloads. + Base64StdFormat Base64Encoding = 0 + + // Base64URLFormat uses base64.RawURLEncoding for encoding and decoding payloads. + Base64URLFormat Base64Encoding = 1 +) + +// NewRequest creates a new policy.Request with the specified input. +func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*pipeline.Request, error) { + return pipeline.NewRequest(ctx, httpMethod, endpoint) +} + +// NewPipeline creates a new Pipeline object from the specified Transport and Policies. +// If no transport is provided then the default *http.Client transport will be used. +func NewPipeline(transport pipeline.Transporter, policies ...pipeline.Policy) pipeline.Pipeline { + if transport == nil { + transport = defaultHTTPClient + } + // transport policy must always be the last in the slice + policies = append(policies, pipeline.PolicyFunc(httpHeaderPolicy), pipeline.PolicyFunc(bodyDownloadPolicy)) + return pipeline.NewPipeline(transport, policies...) +} + +// JoinPaths concatenates multiple URL path segments into one path, +// inserting path separation characters as required. JoinPaths will preserve +// query parameters in the root path +func JoinPaths(root string, paths ...string) string { + if len(paths) == 0 { + return root + } + + qps := "" + if strings.Contains(root, "?") { + splitPath := strings.Split(root, "?") + root, qps = splitPath[0], splitPath[1] + } + + for i := 0; i < len(paths); i++ { + root = strings.TrimRight(root, "/") + paths[i] = strings.TrimLeft(paths[i], "/") + root += "/" + paths[i] + } + + if qps != "" { + if !strings.HasSuffix(root, "/") { + root += "/" + } + return root + "?" + qps + } + return root +} + +// EncodeByteArray will base-64 encode the byte slice v. +func EncodeByteArray(v []byte, format Base64Encoding) string { + if format == Base64URLFormat { + return base64.RawURLEncoding.EncodeToString(v) + } + return base64.StdEncoding.EncodeToString(v) +} + +// MarshalAsByteArray will base-64 encode the byte slice v, then calls SetBody. +// The encoded value is treated as a JSON string. +func MarshalAsByteArray(req *policy.Request, v []byte, format Base64Encoding) error { + // send as a JSON string + encode := fmt.Sprintf("\"%s\"", EncodeByteArray(v, format)) + return req.SetBody(shared.NopCloser(strings.NewReader(encode)), shared.ContentTypeAppJSON) +} + +// MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody. +func MarshalAsJSON(req *policy.Request, v interface{}) error { + v = cloneWithoutReadOnlyFields(v) + b, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("error marshalling type %T: %s", v, err) + } + return req.SetBody(shared.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppJSON) +} + +// MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody. +func MarshalAsXML(req *policy.Request, v interface{}) error { + b, err := xml.Marshal(v) + if err != nil { + return fmt.Errorf("error marshalling type %T: %s", v, err) + } + return req.SetBody(shared.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppXML) +} + +// SetMultipartFormData writes the specified keys/values as multi-part form +// fields with the specified value. File content must be specified as a ReadSeekCloser. +// All other values are treated as string values. +func SetMultipartFormData(req *policy.Request, formData map[string]interface{}) error { + body := bytes.Buffer{} + writer := multipart.NewWriter(&body) + for k, v := range formData { + if rsc, ok := v.(io.ReadSeekCloser); ok { + // this is the body to upload, the key is its file name + fd, err := writer.CreateFormFile(k, k) + if err != nil { + return err + } + // copy the data to the form file + if _, err = io.Copy(fd, rsc); err != nil { + return err + } + continue + } + // ensure the value is in string format + s, ok := v.(string) + if !ok { + s = fmt.Sprintf("%v", v) + } + if err := writer.WriteField(k, s); err != nil { + return err + } + } + if err := writer.Close(); err != nil { + return err + } + return req.SetBody(shared.NopCloser(bytes.NewReader(body.Bytes())), writer.FormDataContentType()) +} + +// returns a clone of the object graph pointed to by v, omitting values of all read-only +// fields. if there are no read-only fields in the object graph, no clone is created. +func cloneWithoutReadOnlyFields(v interface{}) interface{} { + val := reflect.Indirect(reflect.ValueOf(v)) + if val.Kind() != reflect.Struct { + // not a struct, skip + return v + } + // first walk the graph to find any R/O fields. + // if there aren't any, skip cloning the graph. + if !recursiveFindReadOnlyField(val) { + return v + } + return recursiveCloneWithoutReadOnlyFields(val) +} + +// returns true if any field in the object graph of val contains the `azure:"ro"` tag value +func recursiveFindReadOnlyField(val reflect.Value) bool { + t := val.Type() + // iterate over the fields, looking for the "azure" tag. + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + aztag := field.Tag.Get("azure") + if azureTagIsReadOnly(aztag) { + return true + } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct && recursiveFindReadOnlyField(reflect.Indirect(val.Field(i))) { + return true + } + } + return false +} + +// clones the object graph of val. all non-R/O properties are copied to the clone +func recursiveCloneWithoutReadOnlyFields(val reflect.Value) interface{} { + clone := reflect.New(val.Type()) + t := val.Type() + // iterate over the fields, looking for the "azure" tag. + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + aztag := field.Tag.Get("azure") + if azureTagIsReadOnly(aztag) { + // omit from payload + } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct { + // recursive case + v := recursiveCloneWithoutReadOnlyFields(reflect.Indirect(val.Field(i))) + if t.Field(i).Anonymous { + // NOTE: this does not handle the case of embedded fields of unexported struct types. + // this should be ok as we don't generate any code like this at present + reflect.Indirect(clone).Field(i).Set(reflect.Indirect(reflect.ValueOf(v))) + } else { + reflect.Indirect(clone).Field(i).Set(reflect.ValueOf(v)) + } + } else { + // no azure RO tag, non-recursive case, include in payload + reflect.Indirect(clone).Field(i).Set(val.Field(i)) + } + } + return clone.Interface() +} + +// returns true if the "azure" tag contains the option "ro" +func azureTagIsReadOnly(tag string) bool { + if tag == "" { + return false + } + parts := strings.Split(tag, ",") + for _, part := range parts { + if part == "ro" { + return true + } + } + return false +} diff --git a/sdk/azcore/request_test.go b/sdk/azcore/runtime/request_test.go similarity index 82% rename from sdk/azcore/request_test.go rename to sdk/azcore/runtime/request_test.go index 98c80e7c795e..aa1bb59a14a9 100644 --- a/sdk/azcore/request_test.go +++ b/sdk/azcore/runtime/request_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -20,6 +20,8 @@ import ( "strings" "testing" "unsafe" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) type testJSON struct { @@ -37,17 +39,17 @@ func TestRequestMarshalXML(t *testing.T) { if err != nil { t.Fatal(err) } - err = req.MarshalAsXML(testXML{SomeInt: 1, SomeString: "s"}) + err = MarshalAsXML(req, testXML{SomeInt: 1, SomeString: "s"}) if err != nil { t.Fatalf("marshal failure: %v", err) } - if ct := req.Header.Get(headerContentType); ct != contentTypeAppXML { - t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppXML) + if ct := req.Raw().Header.Get(shared.HeaderContentType); ct != shared.ContentTypeAppXML { + t.Fatalf("unexpected content type, got %s wanted %s", ct, shared.ContentTypeAppXML) } - if req.Body == nil { + if req.Raw().Body == nil { t.Fatal("unexpected nil request body") } - if req.ContentLength == 0 { + if req.Raw().ContentLength == 0 { t.Fatal("unexpected zero content length") } } @@ -71,17 +73,17 @@ func TestRequestMarshalJSON(t *testing.T) { if err != nil { t.Fatal(err) } - err = req.MarshalAsJSON(testJSON{SomeInt: 1, SomeString: "s"}) + err = MarshalAsJSON(req, testJSON{SomeInt: 1, SomeString: "s"}) if err != nil { t.Fatalf("marshal failure: %v", err) } - if ct := req.Header.Get(headerContentType); ct != contentTypeAppJSON { - t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppJSON) + if ct := req.Raw().Header.Get(shared.HeaderContentType); ct != shared.ContentTypeAppJSON { + t.Fatalf("unexpected content type, got %s wanted %s", ct, shared.ContentTypeAppJSON) } - if req.Body == nil { + if req.Raw().Body == nil { t.Fatal("unexpected nil request body") } - if req.ContentLength == 0 { + if req.Raw().ContentLength == 0 { t.Fatal("unexpected zero content length") } } @@ -92,20 +94,20 @@ func TestRequestMarshalAsByteArrayURLFormat(t *testing.T) { t.Fatal(err) } const payload = "a string that gets encoded with base64url" - err = req.MarshalAsByteArray([]byte(payload), Base64URLFormat) + err = MarshalAsByteArray(req, []byte(payload), Base64URLFormat) if err != nil { t.Fatalf("marshal failure: %v", err) } - if ct := req.Header.Get(headerContentType); ct != contentTypeAppJSON { - t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppJSON) + if ct := req.Raw().Header.Get(shared.HeaderContentType); ct != shared.ContentTypeAppJSON { + t.Fatalf("unexpected content type, got %s wanted %s", ct, shared.ContentTypeAppJSON) } - if req.Body == nil { + if req.Raw().Body == nil { t.Fatal("unexpected nil request body") } - if req.ContentLength == 0 { + if req.Raw().ContentLength == 0 { t.Fatal("unexpected zero content length") } - b, err := ioutil.ReadAll(req.Body) + b, err := ioutil.ReadAll(req.Raw().Body) if err != nil { t.Fatal(err) } @@ -120,20 +122,20 @@ func TestRequestMarshalAsByteArrayStdFormat(t *testing.T) { t.Fatal(err) } const payload = "a string that gets encoded with base64url" - err = req.MarshalAsByteArray([]byte(payload), Base64StdFormat) + err = MarshalAsByteArray(req, []byte(payload), Base64StdFormat) if err != nil { t.Fatalf("marshal failure: %v", err) } - if ct := req.Header.Get(headerContentType); ct != contentTypeAppJSON { - t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppJSON) + if ct := req.Raw().Header.Get(shared.HeaderContentType); ct != shared.ContentTypeAppJSON { + t.Fatalf("unexpected content type, got %s wanted %s", ct, shared.ContentTypeAppJSON) } - if req.Body == nil { + if req.Raw().Body == nil { t.Fatal("unexpected nil request body") } - if req.ContentLength == 0 { + if req.Raw().ContentLength == 0 { t.Fatal("unexpected zero content length") } - b, err := ioutil.ReadAll(req.Body) + b, err := ioutil.ReadAll(req.Raw().Body) if err != nil { t.Fatal(err) } @@ -337,11 +339,11 @@ func TestCloneWithoutReadOnlyFieldsEndToEnd(t *testing.T) { ID: &id, Name: &name, } - err = req.MarshalAsJSON(nro) + err = MarshalAsJSON(req, nro) if err != nil { t.Fatal(err) } - b, err := ioutil.ReadAll(req.Body) + b, err := ioutil.ReadAll(req.Raw().Body) if err != nil { t.Fatal(err) } @@ -476,36 +478,12 @@ func TestRequestSetBodyContentLengthHeader(t *testing.T) { for i := 0; i < buffLen; i++ { buff[i] = 1 } - err = req.SetBody(NopCloser(bytes.NewReader(buff)), "application/octet-stream") + err = req.SetBody(shared.NopCloser(bytes.NewReader(buff)), "application/octet-stream") if err != nil { t.Fatal(err) } - if req.Header.Get(headerContentLength) != strconv.FormatInt(buffLen, 10) { - t.Fatalf("expected content-length %d, got %s", buffLen, req.Header.Get(headerContentLength)) - } -} - -func TestNewRequestFail(t *testing.T) { - req, err := NewRequest(context.Background(), http.MethodOptions, "://test.contoso.com/") - if err == nil { - t.Fatal("unexpected nil error") - } - if req != nil { - t.Fatal("unexpected request") - } - req, err = NewRequest(context.Background(), http.MethodPatch, "/missing/the/host") - if err == nil { - t.Fatal("unexpected nil error") - } - if req != nil { - t.Fatal("unexpected request") - } - req, err = NewRequest(context.Background(), http.MethodPatch, "mailto://nobody.contoso.com") - if err == nil { - t.Fatal("unexpected nil error") - } - if req != nil { - t.Fatal("unexpected request") + if req.Raw().Header.Get(shared.HeaderContentLength) != strconv.FormatInt(buffLen, 10) { + t.Fatalf("expected content-length %d, got %s", buffLen, req.Raw().Header.Get(shared.HeaderContentLength)) } } @@ -529,7 +507,7 @@ func TestRequestValidFail(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("inval d", "header") + req.Raw().Header.Add("inval d", "header") p := NewPipeline(nil) resp, err := p.Do(req) if err == nil { @@ -538,9 +516,9 @@ func TestRequestValidFail(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - req.Header = http.Header{} + req.Raw().Header = http.Header{} // the string "null\0" - req.Header.Add("invalid", string([]byte{0x6e, 0x75, 0x6c, 0x6c, 0x0})) + req.Raw().Header.Add("invalid", string([]byte{0x6e, 0x75, 0x6c, 0x6c, 0x0})) resp, err = p.Do(req) if err == nil { t.Fatal("unexpected nil error") @@ -555,22 +533,22 @@ func TestSetMultipartFormData(t *testing.T) { if err != nil { t.Fatal(err) } - err = req.SetMultipartFormData(map[string]interface{}{ + err = SetMultipartFormData(req, map[string]interface{}{ "string": "value", "int": 1, - "data": NopCloser(strings.NewReader("some data")), + "data": shared.NopCloser(strings.NewReader("some data")), }) if err != nil { t.Fatal(err) } - mt, params, err := mime.ParseMediaType(req.Header.Get(headerContentType)) + mt, params, err := mime.ParseMediaType(req.Raw().Header.Get(shared.HeaderContentType)) if err != nil { t.Fatal(err) } if mt != "multipart/form-data" { t.Fatalf("unexpected media type %s", mt) } - reader := multipart.NewReader(req.Body, params["boundary"]) + reader := multipart.NewReader(req.Raw().Body, params["boundary"]) for { part, err := reader.NextPart() if err == io.EOF { diff --git a/sdk/azcore/response.go b/sdk/azcore/runtime/response.go similarity index 76% rename from sdk/azcore/response.go rename to sdk/azcore/runtime/response.go index 59cb195c417f..c0a990e8aa19 100644 --- a/sdk/azcore/response.go +++ b/sdk/azcore/runtime/response.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -16,9 +16,10 @@ import ( "io/ioutil" "net/http" "sort" - "strconv" "strings" - "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) // Payload reads and returns the response body or an error. @@ -39,15 +40,7 @@ func Payload(resp *http.Response) ([]byte, error) { // HasStatusCode returns true if the Response's status code is one of the specified values. func HasStatusCode(resp *http.Response, statusCodes ...int) bool { - if resp == nil { - return false - } - for _, sc := range statusCodes { - if resp.StatusCode == sc { - return true - } - } - return false + return shared.HasStatusCode(resp, statusCodes...) } // UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v. @@ -123,47 +116,6 @@ func removeBOM(resp *http.Response) error { return nil } -// writes to a buffer, used for logging purposes -func writeBody(resp *http.Response, b *bytes.Buffer) error { - ct := resp.Header.Get(headerContentType) - if ct == "" { - fmt.Fprint(b, " Response contained no body\n") - return nil - } else if !shouldLogBody(b, ct) { - return nil - } - body, err := Payload(resp) - if err != nil { - fmt.Fprintf(b, " Failed to read response body: %s\n", err.Error()) - return err - } - if len(body) > 0 { - logBody(b, body) - } else { - fmt.Fprint(b, " Response contained no body\n") - } - return nil -} - -// RetryAfter returns non-zero if the response contains a Retry-After header value. -func RetryAfter(resp *http.Response) time.Duration { - if resp == nil { - return 0 - } - ra := resp.Header.Get(headerRetryAfter) - if ra == "" { - return 0 - } - // retry-after values are expressed in either number of - // seconds or an HTTP-date indicating when to try again - if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 { - return time.Duration(retryAfter) * time.Second - } else if t, err := time.Parse(time.RFC1123, ra); err == nil { - return time.Until(t) - } - return 0 -} - // DecodeByteArray will base-64 decode the provided string into v. func DecodeByteArray(s string, v *[]byte, format Base64Encoding) error { if len(s) == 0 { @@ -197,10 +149,10 @@ func DecodeByteArray(s string, v *[]byte, format Base64Encoding) error { // writeRequestWithResponse appends a formatted HTTP request into a Buffer. If request and/or err are // not nil, then these are also written into the Buffer. -func writeRequestWithResponse(b *bytes.Buffer, request *Request, resp *http.Response, err error) { +func writeRequestWithResponse(b *bytes.Buffer, req *policy.Request, resp *http.Response, err error) { // Write the request into the buffer. - fmt.Fprint(b, " "+request.Method+" "+request.URL.String()+"\n") - writeHeader(b, request.Header) + fmt.Fprint(b, " "+req.Raw().Method+" "+req.Raw().URL.String()+"\n") + writeHeader(b, req.Raw().Header) if resp != nil { fmt.Fprintln(b, " --------------------------------------------------------------------------------") fmt.Fprint(b, " RESPONSE Status: "+resp.Status+"\n") diff --git a/sdk/azcore/response_test.go b/sdk/azcore/runtime/response_test.go similarity index 88% rename from sdk/azcore/response_test.go rename to sdk/azcore/runtime/response_test.go index 9e325fa048bc..cfd867a74379 100644 --- a/sdk/azcore/response_test.go +++ b/sdk/azcore/runtime/response_test.go @@ -4,13 +4,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" "net/http" "testing" - "time" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -152,33 +151,6 @@ func TestResponseUnmarshalXMLNoBody(t *testing.T) { } } -func TestRetryAfter(t *testing.T) { - resp := &http.Response{ - Header: http.Header{}, - } - if d := RetryAfter(resp); d > 0 { - t.Fatalf("unexpected retry-after value %d", d) - } - resp.Header.Set(headerRetryAfter, "300") - d := RetryAfter(resp) - if d <= 0 { - t.Fatal("expected retry-after value from seconds") - } - if d != 300*time.Second { - t.Fatalf("expected 300 seconds, got %d", d/time.Second) - } - atDate := time.Now().Add(600 * time.Second) - resp.Header.Set(headerRetryAfter, atDate.Format(time.RFC1123)) - d = RetryAfter(resp) - if d <= 0 { - t.Fatal("expected retry-after value from date") - } - // d will not be exactly 600 seconds but it will be close - if s := d / time.Second; s < 598 || s > 602 { - t.Fatalf("expected ~600 seconds, got %d", s) - } -} - func TestResponseUnmarshalAsByteArrayURLFormat(t *testing.T) { srv, close := mock.NewServer() defer close() diff --git a/sdk/azcore/runtime/transport_default_http_client.go b/sdk/azcore/runtime/transport_default_http_client.go new file mode 100644 index 000000000000..4352f916c455 --- /dev/null +++ b/sdk/azcore/runtime/transport_default_http_client.go @@ -0,0 +1,35 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "crypto/tls" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +var defaultHTTPClient *http.Client + +func init() { + defaultTransport := http.DefaultTransport.(*http.Transport).Clone() + defaultTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 + defaultHTTPClient = &http.Client{ + Transport: defaultTransport, + } +} + +// AuthenticationOptions contains various options used to create a credential policy. +type AuthenticationOptions struct { + // TokenRequest is a TokenRequestOptions that includes a scopes field which contains + // the list of OAuth2 authentication scopes used when requesting a token. + // This field is ignored for other forms of authentication (e.g. shared key). + TokenRequest policy.TokenRequestOptions + // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate + // in cross-tenant applications. + AuxiliaryTenants []string +} diff --git a/sdk/azcore/progress.go b/sdk/azcore/streaming/progress.go similarity index 78% rename from sdk/azcore/progress.go rename to sdk/azcore/streaming/progress.go index cfdd2bf1d902..ca0b05c80812 100644 --- a/sdk/azcore/progress.go +++ b/sdk/azcore/streaming/progress.go @@ -4,21 +4,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package streaming import ( "io" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) type progress struct { rc io.ReadCloser - rsc ReadSeekCloser + rsc io.ReadSeekCloser pr func(bytesTransferred int64) offset int64 } +// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. +func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { + return shared.NopCloser(rs) +} + // NewRequestProgress adds progress reporting to an HTTP request's body stream. -func NewRequestProgress(body ReadSeekCloser, pr func(bytesTransferred int64)) ReadSeekCloser { +func NewRequestProgress(body io.ReadSeekCloser, pr func(bytesTransferred int64)) io.ReadSeekCloser { return &progress{ rc: body, rsc: body, diff --git a/sdk/azcore/progress_test.go b/sdk/azcore/streaming/progress_test.go similarity index 89% rename from sdk/azcore/progress_test.go rename to sdk/azcore/streaming/progress_test.go index bcf6e7abde14..cf68bdf5eb19 100644 --- a/sdk/azcore/progress_test.go +++ b/sdk/azcore/streaming/progress_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package streaming import ( "bytes" @@ -15,6 +15,7 @@ import ( "reflect" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -28,8 +29,8 @@ func TestProgressReporting(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithBody(content)) - pl := NewPipeline(srv) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + pl := runtime.NewPipeline(srv) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -77,8 +78,8 @@ func TestProgressReportingSeek(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithBody(content)) - pl := NewPipeline(srv) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + pl := runtime.NewPipeline(srv) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/sdk/azcore/to/to.go b/sdk/azcore/to/to.go new file mode 100644 index 000000000000..01bb033ef03c --- /dev/null +++ b/sdk/azcore/to/to.go @@ -0,0 +1,107 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package to + +import "time" + +// BoolPtr returns a pointer to the provided bool. +func BoolPtr(b bool) *bool { + return &b +} + +// Float32Ptr returns a pointer to the provided float32. +func Float32Ptr(i float32) *float32 { + return &i +} + +// Float64Ptr returns a pointer to the provided float64. +func Float64Ptr(i float64) *float64 { + return &i +} + +// Int32Ptr returns a pointer to the provided int32. +func Int32Ptr(i int32) *int32 { + return &i +} + +// Int64Ptr returns a pointer to the provided int64. +func Int64Ptr(i int64) *int64 { + return &i +} + +// StringPtr returns a pointer to the provided string. +func StringPtr(s string) *string { + return &s +} + +// TimePtr returns a pointer to the provided time.Time. +func TimePtr(t time.Time) *time.Time { + return &t +} + +// Int32PtrArray returns an array of *int32 from the specified values. +func Int32PtrArray(vals ...int32) []*int32 { + arr := make([]*int32, len(vals)) + for i := range vals { + arr[i] = Int32Ptr(vals[i]) + } + return arr +} + +// Int64PtrArray returns an array of *int64 from the specified values. +func Int64PtrArray(vals ...int64) []*int64 { + arr := make([]*int64, len(vals)) + for i := range vals { + arr[i] = Int64Ptr(vals[i]) + } + return arr +} + +// Float32PtrArray returns an array of *float32 from the specified values. +func Float32PtrArray(vals ...float32) []*float32 { + arr := make([]*float32, len(vals)) + for i := range vals { + arr[i] = Float32Ptr(vals[i]) + } + return arr +} + +// Float64PtrArray returns an array of *float64 from the specified values. +func Float64PtrArray(vals ...float64) []*float64 { + arr := make([]*float64, len(vals)) + for i := range vals { + arr[i] = Float64Ptr(vals[i]) + } + return arr +} + +// BoolPtrArray returns an array of *bool from the specified values. +func BoolPtrArray(vals ...bool) []*bool { + arr := make([]*bool, len(vals)) + for i := range vals { + arr[i] = BoolPtr(vals[i]) + } + return arr +} + +// StringPtrArray returns an array of *string from the specified values. +func StringPtrArray(vals ...string) []*string { + arr := make([]*string, len(vals)) + for i := range vals { + arr[i] = StringPtr(vals[i]) + } + return arr +} + +// TimePtrArray returns an array of *time.Time from the specified values. +func TimePtrArray(vals ...time.Time) []*time.Time { + arr := make([]*time.Time, len(vals)) + for i := range vals { + arr[i] = TimePtr(vals[i]) + } + return arr +} diff --git a/sdk/azcore/to/to_test.go b/sdk/azcore/to/to_test.go new file mode 100644 index 000000000000..ef9374b0ceda --- /dev/null +++ b/sdk/azcore/to/to_test.go @@ -0,0 +1,192 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package to + +import ( + "fmt" + "reflect" + "strconv" + "testing" + "time" +) + +func TestBoolPtr(t *testing.T) { + b := true + pb := BoolPtr(b) + if pb == nil { + t.Fatal("unexpected nil conversion") + } + if *pb != b { + t.Fatalf("got %v, want %v", *pb, b) + } +} + +func TestFloat32Ptr(t *testing.T) { + f32 := float32(3.1415926) + pf32 := Float32Ptr(f32) + if pf32 == nil { + t.Fatal("unexpected nil conversion") + } + if *pf32 != f32 { + t.Fatalf("got %v, want %v", *pf32, f32) + } +} + +func TestFloat64Ptr(t *testing.T) { + f64 := float64(2.71828182845904) + pf64 := Float64Ptr(f64) + if pf64 == nil { + t.Fatal("unexpected nil conversion") + } + if *pf64 != f64 { + t.Fatalf("got %v, want %v", *pf64, f64) + } +} + +func TestInt32Ptr(t *testing.T) { + i32 := int32(123456789) + pi32 := Int32Ptr(i32) + if pi32 == nil { + t.Fatal("unexpected nil conversion") + } + if *pi32 != i32 { + t.Fatalf("got %v, want %v", *pi32, i32) + } +} + +func TestInt64Ptr(t *testing.T) { + i64 := int64(9876543210) + pi64 := Int64Ptr(i64) + if pi64 == nil { + t.Fatal("unexpected nil conversion") + } + if *pi64 != i64 { + t.Fatalf("got %v, want %v", *pi64, i64) + } +} + +func TestStringPtr(t *testing.T) { + s := "the string" + ps := StringPtr(s) + if ps == nil { + t.Fatal("unexpected nil conversion") + } + if *ps != s { + t.Fatalf("got %v, want %v", *ps, s) + } +} + +func TestTimePtr(t *testing.T) { + tt := time.Now() + pt := TimePtr(tt) + if pt == nil { + t.Fatal("unexpected nil conversion") + } + if *pt != tt { + t.Fatalf("got %v, want %v", *pt, tt) + } +} + +func TestInt32PtrArray(t *testing.T) { + arr := Int32PtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = Int32PtrArray(1, 2, 3, 4, 5) + for i, v := range arr { + if *v != int32(i+1) { + t.Fatal("values don't match") + } + } +} + +func TestInt64PtrArray(t *testing.T) { + arr := Int64PtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = Int64PtrArray(1, 2, 3, 4, 5) + for i, v := range arr { + if *v != int64(i+1) { + t.Fatal("values don't match") + } + } +} + +func TestFloat32PtrArray(t *testing.T) { + arr := Float32PtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = Float32PtrArray(1.1, 2.2, 3.3, 4.4, 5.5) + for i, v := range arr { + f, err := strconv.ParseFloat(fmt.Sprintf("%d.%d", i+1, i+1), 32) + if err != nil { + t.Fatal(err) + } + if *v != float32(f) { + t.Fatal("values don't match") + } + } +} + +func TestFloat64PtrArray(t *testing.T) { + arr := Float64PtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = Float64PtrArray(1.1, 2.2, 3.3, 4.4, 5.5) + for i, v := range arr { + f, err := strconv.ParseFloat(fmt.Sprintf("%d.%d", i+1, i+1), 64) + if err != nil { + t.Fatal(err) + } + if *v != f { + t.Fatal("values don't match") + } + } +} + +func TestBoolPtrArray(t *testing.T) { + arr := BoolPtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = BoolPtrArray(true, false, true) + curr := true + for _, v := range arr { + if *v != curr { + t.Fatal("values don'p match") + } + curr = !curr + } +} + +func TestStringPtrArray(t *testing.T) { + arr := StringPtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = StringPtrArray("one", "", "three") + if !reflect.DeepEqual(arr, []*string{StringPtr("one"), StringPtr(""), StringPtr("three")}) { + t.Fatal("values don't match") + } +} + +func TestTimePtrArray(t *testing.T) { + arr := TimePtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + t1 := time.Now() + t2 := time.Time{} + t3 := t1.Add(24 * time.Hour) + arr = TimePtrArray(t1, t2, t3) + if !reflect.DeepEqual(arr, []*time.Time{&t1, &t2, &t3}) { + t.Fatal("values don't match") + } +} diff --git a/sdk/azcore/transport_default_http_client.go b/sdk/azcore/transport_default_http_client.go deleted file mode 100644 index 02a36bcbe741..000000000000 --- a/sdk/azcore/transport_default_http_client.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "crypto/tls" - "net/http" -) - -var defaultHTTPClient *http.Client - -func init() { - defaultTransport := http.DefaultTransport.(*http.Transport).Clone() - defaultTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 - defaultHTTPClient = &http.Client{ - Transport: defaultTransport, - } -} diff --git a/sdk/azcore/version.go b/sdk/azcore/version.go deleted file mode 100644 index 2220616518bc..000000000000 --- a/sdk/azcore/version.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -const ( - // UserAgent is the string to be used in the user agent string when making requests. - UserAgent = "azcore/" + Version - - // Version is the semantic version (see http://semver.org) of this module. - Version = "v0.18.0" -)