diff --git a/sdk/armcore/internal/pollers/async/async.go b/sdk/armcore/internal/pollers/async/async.go index acfe81b1ac7a..5e25eef930ef 100644 --- a/sdk/armcore/internal/pollers/async/async.go +++ b/sdk/armcore/internal/pollers/async/async.go @@ -7,6 +7,7 @@ package async import ( "errors" + "fmt" "net/http" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" @@ -42,6 +43,9 @@ func New(resp *azcore.Response, finalState string, pollerID string) (*Poller, er if asyncURL == "" { return nil, errors.New("response is missing Azure-AsyncOperation header") } + if !pollers.IsValidURL(asyncURL) { + return nil, fmt.Errorf("invalid polling URL %s", asyncURL) + } p := &Poller{ Type: pollers.MakeID(pollerID, "async"), AsyncURL: asyncURL, diff --git a/sdk/armcore/internal/pollers/loc/loc.go b/sdk/armcore/internal/pollers/loc/loc.go index c01b5b33f425..853410bd66ea 100644 --- a/sdk/armcore/internal/pollers/loc/loc.go +++ b/sdk/armcore/internal/pollers/loc/loc.go @@ -6,6 +6,7 @@ package loc import ( + "fmt" "net/http" "github.com/Azure/azure-sdk-for-go/sdk/armcore/internal/pollers" @@ -27,9 +28,13 @@ type Poller struct { // New creates a new Poller from the provided initial response. func New(resp *azcore.Response, pollerID string) (*Poller, error) { azcore.Log().Write(azcore.LogLongRunningOperation, "Using Location poller.") + locURL := resp.Header.Get(pollers.HeaderLocation) + if !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid polling URL %s", locURL) + } p := &Poller{ Type: pollers.MakeID(pollerID, "loc"), - PollURL: resp.Header.Get(pollers.HeaderLocation), + PollURL: locURL, CurState: "InProgress", } return p, nil diff --git a/sdk/armcore/internal/pollers/pollers.go b/sdk/armcore/internal/pollers/pollers.go index 0ee2c0c0e966..0f0081c80012 100644 --- a/sdk/armcore/internal/pollers/pollers.go +++ b/sdk/armcore/internal/pollers/pollers.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io/ioutil" + "net/url" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -108,6 +109,12 @@ func GetProvisioningState(resp *azcore.Response) (string, error) { return "", ErrNoProvisioningState } +// IsValidURL verifies that the URL is valid and absolute. +func IsValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() +} + // MakeID returns the unique poller identifier in the format pollerID;poller. func MakeID(pollerID string, kind string) string { return fmt.Sprintf("%s;%s", pollerID, kind) diff --git a/sdk/armcore/internal/pollers/pollers_test.go b/sdk/armcore/internal/pollers/pollers_test.go index 80d06df2e7f6..6a502df99ba9 100644 --- a/sdk/armcore/internal/pollers/pollers_test.go +++ b/sdk/armcore/internal/pollers/pollers_test.go @@ -109,3 +109,12 @@ func TestMakeID(t *testing.T) { t.Fatalf("unexpected poller kind %s", p) } } + +func TestIsValidURL(t *testing.T) { + if IsValidURL("/foo") { + t.Fatal("unexpected valid URL") + } + if !IsValidURL("https://foo.bar/baz") { + t.Fatal("expected valid URL") + } +}