From 937708ba1b4e02522901d31121879b02dccb9def Mon Sep 17 00:00:00 2001 From: Catalina Peralta Date: Wed, 22 Jul 2020 13:21:46 -0700 Subject: [PATCH 1/2] Adding azidentity --- sdk/azidentity/aad_identity_client.go | 432 ++++++++++++++++++ sdk/azidentity/aad_identity_client_test.go | 37 ++ sdk/azidentity/azidentity.go | 210 +++++++++ sdk/azidentity/azidentity_test.go | 134 ++++++ sdk/azidentity/azure_cli_credential.go | 176 +++++++ sdk/azidentity/azure_cli_credential_test.go | 76 +++ sdk/azidentity/bearer_token_policy.go | 107 +++++ sdk/azidentity/bearer_token_policy_test.go | 156 +++++++ sdk/azidentity/chained_token_credential.go | 77 ++++ .../chained_token_credential_test.go | 181 ++++++++ .../client_certificate_credential.go | 59 +++ .../client_certificate_credential_test.go | 193 ++++++++ sdk/azidentity/client_secret_credential.go | 55 +++ .../client_secret_credential_test.go | 140 ++++++ sdk/azidentity/default_azure_credential.go | 62 +++ .../default_azure_credential_test.go | 90 ++++ sdk/azidentity/device_code_credential.go | 112 +++++ sdk/azidentity/device_code_credential_test.go | 291 ++++++++++++ sdk/azidentity/environment_credential.go | 38 ++ sdk/azidentity/environment_credential_test.go | 111 +++++ sdk/azidentity/fingerprint.go | 79 ++++ sdk/azidentity/go.mod | 8 + sdk/azidentity/go.sum | 4 + sdk/azidentity/jwt.go | 88 ++++ sdk/azidentity/logging.go | 76 +++ sdk/azidentity/managed_identity_client.go | 260 +++++++++++ .../managed_identity_client_test.go | 20 + sdk/azidentity/managed_identity_credential.go | 92 ++++ .../managed_identity_credential_test.go | 279 +++++++++++ sdk/azidentity/testdata/certificate.pem | 49 ++ sdk/azidentity/testdata/certificate_empty.pem | 21 + .../testdata/certificate_formatA.pem | 49 ++ .../testdata/certificate_formatB.pem | 49 ++ sdk/azidentity/testdata/certificate_nokey.pem | 21 + .../username_password_credential.go | 56 +++ .../username_password_credential_test.go | 115 +++++ 36 files changed, 4003 insertions(+) create mode 100644 sdk/azidentity/aad_identity_client.go create mode 100644 sdk/azidentity/aad_identity_client_test.go create mode 100644 sdk/azidentity/azidentity.go create mode 100644 sdk/azidentity/azidentity_test.go create mode 100644 sdk/azidentity/azure_cli_credential.go create mode 100644 sdk/azidentity/azure_cli_credential_test.go create mode 100644 sdk/azidentity/bearer_token_policy.go create mode 100644 sdk/azidentity/bearer_token_policy_test.go create mode 100644 sdk/azidentity/chained_token_credential.go create mode 100644 sdk/azidentity/chained_token_credential_test.go create mode 100644 sdk/azidentity/client_certificate_credential.go create mode 100644 sdk/azidentity/client_certificate_credential_test.go create mode 100644 sdk/azidentity/client_secret_credential.go create mode 100644 sdk/azidentity/client_secret_credential_test.go create mode 100644 sdk/azidentity/default_azure_credential.go create mode 100644 sdk/azidentity/default_azure_credential_test.go create mode 100644 sdk/azidentity/device_code_credential.go create mode 100644 sdk/azidentity/device_code_credential_test.go create mode 100644 sdk/azidentity/environment_credential.go create mode 100644 sdk/azidentity/environment_credential_test.go create mode 100644 sdk/azidentity/fingerprint.go create mode 100644 sdk/azidentity/go.mod create mode 100644 sdk/azidentity/go.sum create mode 100644 sdk/azidentity/jwt.go create mode 100644 sdk/azidentity/logging.go create mode 100644 sdk/azidentity/managed_identity_client.go create mode 100644 sdk/azidentity/managed_identity_client_test.go create mode 100644 sdk/azidentity/managed_identity_credential.go create mode 100644 sdk/azidentity/managed_identity_credential_test.go create mode 100644 sdk/azidentity/testdata/certificate.pem create mode 100644 sdk/azidentity/testdata/certificate_empty.pem create mode 100644 sdk/azidentity/testdata/certificate_formatA.pem create mode 100644 sdk/azidentity/testdata/certificate_formatB.pem create mode 100644 sdk/azidentity/testdata/certificate_nokey.pem create mode 100644 sdk/azidentity/username_password_credential.go create mode 100644 sdk/azidentity/username_password_credential_test.go diff --git a/sdk/azidentity/aad_identity_client.go b/sdk/azidentity/aad_identity_client.go new file mode 100644 index 000000000000..04918743a9dc --- /dev/null +++ b/sdk/azidentity/aad_identity_client.go @@ -0,0 +1,432 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "bufio" + "context" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "path" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + clientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + tokenEndpoint = "/oauth2/v2.0/token/" +) + +const ( + qpClientAssertionType = "client_assertion_type" + qpClientAssertion = "client_assertion" + qpClientID = "client_id" + qpClientSecret = "client_secret" + qpDeviceCode = "device_code" + qpGrantType = "grant_type" + qpPassword = "password" + qpRefreshToken = "refresh_token" + qpResponseType = "response_type" + qpScope = "scope" + qpUsername = "username" +) + +// aadIdentityClient provides the base for authenticating with Client Secret Credentials, Client Certificate Credentials +// and Environment Credentials. This type inlcudes an azcore.Pipeline and TokenCredentialOptions. +type aadIdentityClient struct { + options TokenCredentialOptions + pipeline azcore.Pipeline +} + +// newAADIdentityClient creates a new instance of the aadIdentityClient with the TokenCredentialOptions +// that are passed into it along with a default pipeline. +// options: TokenCredentialOptions that can configure policies for the pipeline and the authority host that +// will be used to retrieve tokens and authenticate +func newAADIdentityClient(options *TokenCredentialOptions) (*aadIdentityClient, error) { + logEnvVars() + options, err := options.setDefaultValues() + if err != nil { + return nil, err + } + return &aadIdentityClient{options: *options, pipeline: newDefaultPipeline(*options)}, nil +} + +// refreshAccessToken creates a refresh token request and returns the resulting Access Token or +// an error in case of an authentication failure. +// ctx: The current request context +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal +// clientID: The client (application) ID of the service principal +// clientSecret: A client secret that was generated for the App Registration used to authenticate the client +// scopes: The scopes for the given access token +func (c *aadIdentityClient) refreshAccessToken(ctx context.Context, tenantID string, clientID string, clientSecret string, refreshToken string, scopes []string) (*tokenResponse, error) { + msg, err := c.createRefreshTokenRequest(tenantID, clientID, clientSecret, refreshToken, scopes) + if err != nil { + return nil, err + } + + resp, err := c.pipeline.Do(ctx, msg) + if err != nil { + return nil, err + } + + if resp.HasStatusCode(successStatusCodes[:]...) { + return c.createRefreshAccessToken(resp) + } + + return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)} +} + +// authenticate creates a client secret authentication request and returns the resulting Access Token or +// an error in case of authentication failure. +// ctx: The current request context +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal +// clientID: The client (application) ID of the service principal +// clientSecret: A client secret that was generated for the App Registration used to authenticate the client +// scopes: The scopes required for the token +func (c *aadIdentityClient) authenticate(ctx context.Context, tenantID string, clientID string, clientSecret string, scopes []string) (*azcore.AccessToken, error) { + msg, err := c.createClientSecretAuthRequest(tenantID, clientID, clientSecret, scopes) + if err != nil { + return nil, err + } + + resp, err := c.pipeline.Do(ctx, msg) + if err != nil { + return nil, err + } + + if resp.HasStatusCode(successStatusCodes[:]...) { + return c.createAccessToken(resp) + } + + return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)} +} + +// authenticateCertificate creates a client certificate authentication request and returns an Access Token or +// an error. +// ctx: The current request context +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal +// clientID: The client (application) ID of the service principal +// clientCertificatePath: The path to the client certificate PEM file +// scopes: The scopes required for the token +func (c *aadIdentityClient) authenticateCertificate(ctx context.Context, tenantID string, clientID string, clientCertificatePath string, scopes []string) (*azcore.AccessToken, error) { + msg, err := c.createClientCertificateAuthRequest(tenantID, clientID, clientCertificatePath, scopes) + if err != nil { + return nil, err + } + + resp, err := c.pipeline.Do(ctx, msg) + if err != nil { + return nil, err + } + + if resp.HasStatusCode(successStatusCodes[:]...) { + return c.createAccessToken(resp) + } + + return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)} +} + +func (c *aadIdentityClient) createAccessToken(res *azcore.Response) (*azcore.AccessToken, error) { + value := struct { + Token string `json:"access_token"` + ExpiresIn json.Number `json:"expires_in"` + ExpiresOn string `json:"expires_on"` + }{} + if err := res.UnmarshalAsJSON(&value); err != nil { + return nil, fmt.Errorf("internal AccessToken: %w", err) + } + t, err := value.ExpiresIn.Int64() + if err != nil { + return nil, err + } + return &azcore.AccessToken{ + Token: value.Token, + ExpiresOn: time.Now().Add(time.Second * time.Duration(t)).UTC(), + }, nil +} + +func (c *aadIdentityClient) createRefreshAccessToken(res *azcore.Response) (*tokenResponse, error) { + // To know more about refreshing access tokens please see: https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-protocols-oauth-code#refreshing-the-access-tokens + // DeviceCodeCredential uses refresh token, please see the authentication flow here: https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code + value := struct { + Token string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn json.Number `json:"expires_in"` + ExpiresOn string `json:"expires_on"` + }{} + if err := res.UnmarshalAsJSON(&value); err != nil { + return nil, fmt.Errorf("internal AccessToken: %w", err) + } + t, err := value.ExpiresIn.Int64() + if err != nil { + return nil, err + } + accessToken := &azcore.AccessToken{ + Token: value.Token, + ExpiresOn: time.Now().Add(time.Second * time.Duration(t)).UTC(), + } + return &tokenResponse{token: accessToken, refreshToken: value.RefreshToken}, nil +} + +func (c *aadIdentityClient) createRefreshTokenRequest(tenantID, clientID, clientSecret, refreshToken string, scopes []string) (*azcore.Request, error) { + u := *c.options.AuthorityHost + u.Path = path.Join(u.Path, tenantID, tokenEndpoint) + data := url.Values{} + data.Set(qpGrantType, "refresh_token") + data.Set(qpClientID, clientID) + // clientSecret is only required for web apps. To know more about refreshing access tokens please see: https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-protocols-oauth-code#refreshing-the-access-tokens + if len(clientSecret) != 0 { + data.Set(qpClientSecret, clientSecret) + } + data.Set(qpRefreshToken, refreshToken) + data.Set(qpScope, strings.Join(scopes, " ")) + dataEncoded := data.Encode() + body := azcore.NopCloser(strings.NewReader(dataEncoded)) + req := azcore.NewRequest(http.MethodPost, u) + req.Header.Set(azcore.HeaderContentType, azcore.HeaderURLEncoded) + err := req.SetBody(body) + if err != nil { + return nil, err + } + return req, nil +} + +func (c *aadIdentityClient) createClientSecretAuthRequest(tenantID string, clientID string, clientSecret string, scopes []string) (*azcore.Request, error) { + u := *c.options.AuthorityHost + u.Path = path.Join(u.Path, tenantID, tokenEndpoint) + data := url.Values{} + data.Set(qpGrantType, "client_credentials") + data.Set(qpClientID, clientID) + data.Set(qpClientSecret, clientSecret) + data.Set(qpScope, strings.Join(scopes, " ")) + dataEncoded := data.Encode() + body := azcore.NopCloser(strings.NewReader(dataEncoded)) + req := azcore.NewRequest(http.MethodPost, u) + req.Header.Set(azcore.HeaderContentType, azcore.HeaderURLEncoded) + err := req.SetBody(body) + if err != nil { + return nil, err + } + + return req, nil +} + +func (c *aadIdentityClient) createClientCertificateAuthRequest(tenantID string, clientID string, clientCertificate string, scopes []string) (*azcore.Request, error) { + u := *c.options.AuthorityHost + u.Path = path.Join(u.Path, tenantID, tokenEndpoint) + clientAssertion, err := createClientAssertionJWT(clientID, u.String(), clientCertificate) + if err != nil { + return nil, err + } + data := url.Values{} + data.Set(qpGrantType, "client_credentials") + data.Set(qpResponseType, "token") + data.Set(qpClientID, clientID) + data.Set(qpClientAssertionType, clientAssertionType) + data.Set(qpClientAssertion, clientAssertion) + data.Set(qpScope, strings.Join(scopes, " ")) + dataEncoded := data.Encode() + body := azcore.NopCloser(strings.NewReader(dataEncoded)) + req := azcore.NewRequest(http.MethodPost, u) + req.Header.Set(azcore.HeaderContentType, azcore.HeaderURLEncoded) + err = req.SetBody(body) + if err != nil { + return nil, err + } + return req, nil +} + +// authenticateUsernamePassword creates a client username and password authentication request and returns an Access Token or +// an error. +// ctx: The current request context +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal +// clientID: The client (application) ID of the service principal +// username: User's account username +// password: User's account password +// scopes: The scopes required for the token +func (c *aadIdentityClient) authenticateUsernamePassword(ctx context.Context, tenantID string, clientID string, username string, password string, scopes []string) (*azcore.AccessToken, error) { + msg, err := c.createUsernamePasswordAuthRequest(tenantID, clientID, username, password, scopes) + if err != nil { + return nil, err + } + + resp, err := c.pipeline.Do(ctx, msg) + if err != nil { + return nil, err + } + + if resp.HasStatusCode(successStatusCodes[:]...) { + return c.createAccessToken(resp) + } + + return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)} +} + +func (c *aadIdentityClient) createUsernamePasswordAuthRequest(tenantID string, clientID string, username string, password string, scopes []string) (*azcore.Request, error) { + u := *c.options.AuthorityHost + u.Path = path.Join(u.Path, tenantID, tokenEndpoint) + data := url.Values{} + data.Set(qpResponseType, "token") + data.Set(qpGrantType, "password") + data.Set(qpClientID, clientID) + data.Set(qpUsername, username) + data.Set(qpPassword, password) + data.Set(qpScope, strings.Join(scopes, " ")) + dataEncoded := data.Encode() + body := azcore.NopCloser(strings.NewReader(dataEncoded)) + req := azcore.NewRequest(http.MethodPost, u) + req.Header.Set(azcore.HeaderContentType, azcore.HeaderURLEncoded) + err := req.SetBody(body) + if err != nil { + return nil, err + } + return req, nil +} + +func createDeviceCodeResult(res *azcore.Response) (*deviceCodeResult, error) { + value := &deviceCodeResult{} + if err := res.UnmarshalAsJSON(&value); err != nil { + return nil, fmt.Errorf("DeviceCodeResult: %w", err) + } + return value, nil +} + +// authenticateDeviceCode creates a device code authentication request and returns an Access Token or +// an error in case of failure. +// ctx: The current request context +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal +// clientID: The client (application) ID of the service principal +// deviceCode: The device code associated with the request +// scopes: The scopes required for the token +func (c *aadIdentityClient) authenticateDeviceCode(ctx context.Context, tenantID string, clientID string, deviceCode string, scopes []string) (*tokenResponse, error) { + msg, err := c.createDeviceCodeAuthRequest(tenantID, clientID, deviceCode, scopes) + if err != nil { + return nil, err + } + + resp, err := c.pipeline.Do(ctx, msg) + if err != nil { + return nil, err + } + + if resp.HasStatusCode(successStatusCodes[:]...) { + return c.createRefreshAccessToken(resp) + } + + return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)} +} + +func (c *aadIdentityClient) createDeviceCodeAuthRequest(tenantID string, clientID string, deviceCode string, scopes []string) (*azcore.Request, error) { + if len(tenantID) == 0 { // if the user did not pass in a tenantID then the default value is set + tenantID = "organizations" + } + u := *c.options.AuthorityHost + u.Path = path.Join(u.Path, tenantID, tokenEndpoint) + data := url.Values{} + data.Set(qpGrantType, deviceCodeGrantType) + data.Set(qpClientID, clientID) + data.Set(qpDeviceCode, deviceCode) + data.Set(qpScope, strings.Join(scopes, " ")) + dataEncoded := data.Encode() + body := azcore.NopCloser(strings.NewReader(dataEncoded)) + req := azcore.NewRequest(http.MethodPost, u) + req.Header.Set(azcore.HeaderContentType, azcore.HeaderURLEncoded) + err := req.SetBody(body) + if err != nil { + return nil, err + } + return req, nil +} + +func (c *aadIdentityClient) requestNewDeviceCode(ctx context.Context, tenantID, clientID string, scopes []string) (*deviceCodeResult, error) { + msg, err := c.createDeviceCodeNumberRequest(tenantID, clientID, scopes) + if err != nil { + return nil, err + } + + resp, err := c.pipeline.Do(ctx, msg) + if err != nil { + return nil, err + } + + if resp.HasStatusCode(successStatusCodes[:]...) { + return createDeviceCodeResult(resp) + } + return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)} +} + +func (c *aadIdentityClient) createDeviceCodeNumberRequest(tenantID string, clientID string, scopes []string) (*azcore.Request, error) { + if len(tenantID) == 0 { // if the user did not pass in a tenantID then the default value is set + tenantID = "organizations" + } + u := *c.options.AuthorityHost + u.Path = path.Join(u.Path, tenantID, "/oauth2/v2.0/devicecode") // endpoint that will return a device code along with the other necessary authentication flow parameters in the DeviceCodeResult struct + data := url.Values{} + data.Set(qpClientID, clientID) + data.Set(qpScope, strings.Join(scopes, " ")) + dataEncoded := data.Encode() + body := azcore.NopCloser(strings.NewReader(dataEncoded)) + req := azcore.NewRequest(http.MethodPost, u) + req.Header.Set(azcore.HeaderContentType, azcore.HeaderURLEncoded) + err := req.SetBody(body) + if err != nil { + return nil, err + } + return req, nil +} + +func getPrivateKey(cert string) (*rsa.PrivateKey, error) { + privateKeyFile, err := os.Open(cert) + if err != nil { + return nil, fmt.Errorf("Opening certificate file path: %w", err) + } + defer privateKeyFile.Close() + + pemFileInfo, err := privateKeyFile.Stat() + if err != nil { + return nil, fmt.Errorf("Getting certificate file info: %w", err) + } + size := pemFileInfo.Size() + + pemBytes := make([]byte, size) + buffer := bufio.NewReader(privateKeyFile) + _, err = buffer.Read(pemBytes) + if err != nil { + return nil, fmt.Errorf("Read PEM file bytes: %w", err) + } + + data, rest := pem.Decode([]byte(pemBytes)) + const privateKeyBlock = "PRIVATE KEY" + // NOTE: check types of private keys + if data.Type != privateKeyBlock { + for len(rest) > 0 { + data, rest = pem.Decode(rest) + if data.Type == privateKeyBlock { + privateKeyImported, err := x509.ParsePKCS8PrivateKey(data.Bytes) + if err != nil { + return nil, fmt.Errorf("ParsePKCS8PrivateKey: %w", err) + } + + return privateKeyImported.(*rsa.PrivateKey), nil + } + } + return nil, errors.New("Cannot find PRIVATE KEY in file") + } + // NOTE: this could be a function local closure + privateKeyImported, err := x509.ParsePKCS8PrivateKey(data.Bytes) + if err != nil { + return nil, fmt.Errorf("ParsePKCS8PrivateKey: %w", err) + } + + return privateKeyImported.(*rsa.PrivateKey), nil +} diff --git a/sdk/azidentity/aad_identity_client_test.go b/sdk/azidentity/aad_identity_client_test.go new file mode 100644 index 000000000000..92d942b28237 --- /dev/null +++ b/sdk/azidentity/aad_identity_client_test.go @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "net/url" + "testing" +) + +func TestAzurePublicCloudParse(t *testing.T) { + _, err := url.Parse(AzurePublicCloud) + if err != nil { + t.Fatalf("Failed to parse default authority host: %v", err) + } +} + +func TestAzureChinaParse(t *testing.T) { + _, err := url.Parse(AzureChina) + if err != nil { + t.Fatalf("Failed to parse AzureChina authority host: %v", err) + } +} + +func TestAzureGermanyParse(t *testing.T) { + _, err := url.Parse(AzureGermany) + if err != nil { + t.Fatalf("Failed to parse AzureGermany authority host: %v", err) + } +} + +func TestAzureGovernmentParse(t *testing.T) { + _, err := url.Parse(AzureGovernment) + if err != nil { + t.Fatalf("Failed to parse AzureGovernment authority host: %v", err) + } +} diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go new file mode 100644 index 000000000000..fbcae97f81e5 --- /dev/null +++ b/sdk/azidentity/azidentity.go @@ -0,0 +1,210 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "net/http" + "net/url" + "os" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + // AzureChina is a global constant to use in order to access the Azure China cloud. + AzureChina = "https://login.chinacloudapi.cn/" + // AzureGermany is a global constant to use in order to access the Azure Germany cloud. + AzureGermany = "https://login.microsoftonline.de/" + // AzureGovernment is a global constant to use in order to access the Azure Government cloud. + AzureGovernment = "https://login.microsoftonline.us/" + // AzurePublicCloud is a global constant to use in order to access the Azure public cloud. + AzurePublicCloud = "https://login.microsoftonline.com/" + // defaultSuffix is a suffix the signals that a string is in scope format + defaultSuffix = "/.default" +) + +var ( + successStatusCodes = [2]int{ + http.StatusOK, // 200 + http.StatusCreated, // 201 + } +) + +type tokenResponse struct { + token *azcore.AccessToken + refreshToken string +} + +// AADAuthenticationFailedError is used to unmarshal error responses received from Azure Active Directory. +type AADAuthenticationFailedError struct { + Message string `json:"error"` + Description string `json:"error_description"` + Timestamp string `json:"timestamp"` + TraceID string `json:"trace_id"` + CorrelationID string `json:"correlation_id"` + URI string `json:"error_uri"` + Response *azcore.Response +} + +func (e *AADAuthenticationFailedError) Error() string { + msg := e.Message + if len(e.Description) > 0 { + msg += " " + e.Description + } + return msg +} + +// AuthenticationFailedError is returned when the authentication request has failed. +type AuthenticationFailedError struct { + inner error + msg string +} + +// Unwrap method on AuthenticationFailedError provides access to the inner error if available. +func (e *AuthenticationFailedError) Unwrap() error { + return e.inner +} + +// IsNotRetriable returns true indicating that this is a terminal error. +func (e *AuthenticationFailedError) IsNotRetriable() bool { + return true +} + +func (e *AuthenticationFailedError) Error() string { + if len(e.msg) == 0 { + e.msg = e.inner.Error() + } + return e.msg +} + +func newAADAuthenticationFailedError(resp *azcore.Response) error { + authFailed := &AADAuthenticationFailedError{Response: resp} + err := resp.UnmarshalAsJSON(authFailed) + if err != nil { + authFailed.Message = resp.Status + authFailed.Description = "Failed to unmarshal response: " + err.Error() + } + return authFailed +} + +// CredentialUnavailableError is the error type returned when the conditions required to +// create a credential do not exist or are unavailable. +type CredentialUnavailableError struct { + // CredentialType holds the name of the credential that is unavailable + CredentialType string + // Message contains the reason why the credential is unavailable + Message string +} + +func (e *CredentialUnavailableError) Error() string { + return e.CredentialType + ": " + e.Message +} + +// IsNotRetriable returns true indicating that this is a terminal error. +func (e *CredentialUnavailableError) IsNotRetriable() bool { + return true +} + +// TokenCredentialOptions are used to configure how requests are made to Azure Active Directory. +type TokenCredentialOptions struct { + // The host of the Azure Active Directory authority. The default is https://login.microsoft.com + AuthorityHost *url.URL + + // HTTPClient sets the transport for making HTTP requests + // Leave this as nil to use the default HTTP transport + HTTPClient azcore.Transport + + // LogOptions configures the built-in request logging policy behavior + LogOptions azcore.RequestLogOptions + + // Retry configures the built-in retry policy behavior + Retry *azcore.RetryOptions + + // Telemetry configures the built-in telemetry policy behavior + Telemetry azcore.TelemetryOptions +} + +// setDefaultValues initializes an instance of TokenCredentialOptions with default settings. +func (c *TokenCredentialOptions) setDefaultValues() (*TokenCredentialOptions, error) { + authorityHost := AzurePublicCloud + if envAuthorityHost := os.Getenv("AZURE_AUTHORITY_HOST"); envAuthorityHost != "" { + authorityHost = envAuthorityHost + } + + if c == nil { + defaultAuthorityHostURL, err := url.Parse(authorityHost) + if err != nil { + return nil, err + } + c = &TokenCredentialOptions{AuthorityHost: defaultAuthorityHostURL} + } + + if c.AuthorityHost == nil { + defaultAuthorityHostURL, err := url.Parse(authorityHost) + if err != nil { + return nil, err + } + c.AuthorityHost = defaultAuthorityHostURL + } + + if len(c.AuthorityHost.Path) == 0 || c.AuthorityHost.Path[len(c.AuthorityHost.Path)-1:] != "/" { + c.AuthorityHost.Path = c.AuthorityHost.Path + "/" + } + + return c, nil +} + +// newDefaultPipeline creates a pipeline using the specified pipeline options. +func newDefaultPipeline(o TokenCredentialOptions) azcore.Pipeline { + if o.HTTPClient == nil { + o.HTTPClient = azcore.DefaultHTTPClientTransport() + } + + return azcore.NewPipeline( + o.HTTPClient, + azcore.NewTelemetryPolicy(o.Telemetry), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(o.Retry), + azcore.NewRequestLogPolicy(o.LogOptions)) +} + +// newDefaultMSIPipeline creates a pipeline using the specified pipeline options needed +// for a Managed Identity, such as a MSI specific retry policy. +func newDefaultMSIPipeline(o ManagedIdentityCredentialOptions) azcore.Pipeline { + if o.HTTPClient == nil { + o.HTTPClient = azcore.DefaultHTTPClientTransport() + } + var statusCodes []int + // retry policy for MSI is not end-user configurable + retryOpts := azcore.RetryOptions{ + MaxRetries: 4, + RetryDelay: 2 * time.Second, + TryTimeout: 1 * time.Minute, + StatusCodes: append(statusCodes, + // The following status codes are a subset of those found in azcore.StatusCodesForRetry, these are the only ones specifically needed for MSI scenarios + http.StatusRequestTimeout, // 408 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusGatewayTimeout, // 504 + http.StatusNotFound, + http.StatusGone, + // all remaining 5xx + http.StatusNotImplemented, + http.StatusHTTPVersionNotSupported, + http.StatusVariantAlsoNegotiates, + http.StatusInsufficientStorage, + http.StatusLoopDetected, + http.StatusNotExtended, + http.StatusNetworkAuthenticationRequired), + } + + return azcore.NewPipeline( + o.HTTPClient, + azcore.NewTelemetryPolicy(o.Telemetry), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(&retryOpts), + azcore.NewRequestLogPolicy(o.LogOptions)) +} diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go new file mode 100644 index 000000000000..433203750481 --- /dev/null +++ b/sdk/azidentity/azidentity_test.go @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "net/url" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + envHostString = "https://mock.com/" + customHostString = "https://custommock.com/" +) + +func Test_AuthorityHost_Parse(t *testing.T) { + _, err := url.Parse(AzurePublicCloud) + if err != nil { + t.Fatalf("Failed to parse default authority host: %v", err) + } +} + +func Test_NonNilTokenCredentialOptsNilAuthorityHost(t *testing.T) { + opts := &TokenCredentialOptions{Retry: &azcore.RetryOptions{MaxRetries: 6}} + opts, err := opts.setDefaultValues() + if err != nil { + t.Fatalf("Received an error: %v", err) + } + if opts.AuthorityHost == nil { + t.Fatalf("Did not set default authority host") + } +} + +func Test_SetEnvAuthorityHost(t *testing.T) { + err := os.Setenv("AZURE_AUTHORITY_HOST", envHostString) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + + opts := &TokenCredentialOptions{} + opts, err = opts.setDefaultValues() + if opts.AuthorityHost.String() != envHostString { + t.Fatalf("Unexpected error when get host from environment vairable: %v", err) + } + + // Unset that host environment vairable to avoid other tests failed. + err = os.Unsetenv("AZURE_AUTHORITY_HOST") + if err != nil { + t.Fatalf("Unexpected error when unset environment vairable: %v", err) + } +} + +func Test_CustomAuthorityHost(t *testing.T) { + err := os.Setenv("AZURE_AUTHORITY_HOST", envHostString) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + + customHost, err := url.Parse(customHostString) + if err != nil { + t.Fatalf("Received an error: %v", err) + } + + opts := &TokenCredentialOptions{AuthorityHost: customHost} + opts, err = opts.setDefaultValues() + if opts.AuthorityHost.String() != customHostString { + t.Fatalf("Unexpected error when get host from environment vairable: %v", err) + } + + // Unset that host environment vairable to avoid other tests failed. + err = os.Unsetenv("AZURE_AUTHORITY_HOST") + if err != nil { + t.Fatalf("Unexpected error when unset environment vairable: %v", err) + } +} + +func Test_DefaultAuthorityHost(t *testing.T) { + opts := &TokenCredentialOptions{} + opts, err := opts.setDefaultValues() + if opts.AuthorityHost.String() != AzurePublicCloud { + t.Fatalf("Unexpected error when set default AuthorityHost: %v", err) + } +} + +func Test_AzureGermanyAuthorityHost(t *testing.T) { + opts := &TokenCredentialOptions{} + opts, err := opts.setDefaultValues() + if err != nil { + t.Fatal(err) + } + u, err := url.Parse(AzureGermany) + if err != nil { + t.Fatal(err) + } + opts.AuthorityHost = u + if opts.AuthorityHost.String() != AzureGermany { + t.Fatalf("Did not retrieve expected authority host string") + } +} + +func Test_AzureChinaAuthorityHost(t *testing.T) { + opts := &TokenCredentialOptions{} + opts, err := opts.setDefaultValues() + if err != nil { + t.Fatal(err) + } + u, err := url.Parse(AzureChina) + if err != nil { + t.Fatal(err) + } + opts.AuthorityHost = u + if opts.AuthorityHost.String() != AzureChina { + t.Fatalf("Did not retrieve expected authority host string") + } +} + +func Test_AzureGovernmentAuthorityHost(t *testing.T) { + opts := &TokenCredentialOptions{} + opts, err := opts.setDefaultValues() + if err != nil { + t.Fatal(err) + } + u, err := url.Parse(AzureGovernment) + if err != nil { + t.Fatal(err) + } + opts.AuthorityHost = u + if opts.AuthorityHost.String() != AzureGovernment { + t.Fatalf("Did not retrieve expected authority host string") + } +} diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go new file mode 100644 index 000000000000..220109ab5453 --- /dev/null +++ b/sdk/azidentity/azure_cli_credential.go @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "regexp" + "runtime" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// AzureCLITokenProvider can be used to supply the AzureCLICredential with an alternate token provider +type AzureCLITokenProvider func(ctx context.Context, resource string) ([]byte, error) + +// AzureCLICredentialOptions contains options used to configure the AzureCLICredential +type AzureCLICredentialOptions struct { + TokenProvider AzureCLITokenProvider +} + +// AzureCLICredential enables authentication to Azure Active Directory using the Azure CLI command "az account get-access-token". +type AzureCLICredential struct { + tokenProvider AzureCLITokenProvider +} + +// NewAzureCLICredential constructs a new AzureCLICredential with the details needed to authenticate against Azure Active Directory +// options: configure the management of the requests sent to Azure Active Directory. +func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredential, error) { + if options == nil { + options = &AzureCLICredentialOptions{TokenProvider: defaultTokenProvider()} + } + return &AzureCLICredential{ + tokenProvider: options.TokenProvider, + }, nil +} + +// GetToken obtains a token from Azure Active Directory, using the Azure CLI command to authenticate. +// ctx: Context used to control the request lifetime. +// opts: TokenRequestOptions contains the list of scopes for which the token will have access. +// Returns an AccessToken which can be used to authenticate service client calls. +func (c *AzureCLICredential) GetToken(ctx context.Context, opts azcore.TokenRequestOptions) (*azcore.AccessToken, error) { + // The following code will remove the /.default suffix from the scope passed into the method since AzureCLI expect a resource string instead of a scope string + opts.Scopes[0] = strings.TrimSuffix(opts.Scopes[0], defaultSuffix) + at, err := c.authenticate(ctx, opts.Scopes[0]) + if err != nil { + addGetTokenFailureLogs("Azure CLI Credential", err) + return nil, err + } + azcore.Log().Write(LogCredential, logGetTokenSuccess(c, opts)) + return at, nil +} + +// AuthenticationPolicy implements the azcore.Credential interface on AzureCLICredential and calls the Bearer Token policy +// to get the bearer token. +func (c *AzureCLICredential) AuthenticationPolicy(options azcore.AuthenticationPolicyOptions) azcore.Policy { + return newBearerTokenPolicy(c, options) +} + +const timeoutCLIRequest = 10000 * time.Millisecond + +// authenticate creates a client secret authentication request and returns the resulting Access Token or +// an error in case of authentication failure. +// ctx: The current request context +// scopes: The scopes for which the token has access +func (c *AzureCLICredential) authenticate(ctx context.Context, resource string) (*azcore.AccessToken, error) { + output, err := c.tokenProvider(ctx, resource) + if err != nil { + return nil, err + } + + return c.createAccessToken(output) +} + +func defaultTokenProvider() func(ctx context.Context, resource string) ([]byte, error) { + return func(ctx context.Context, resource string) ([]byte, error) { + // This is the path that a developer can set to tell this class what the install path for Azure CLI is. + const azureCLIPath = "AZURE_CLI_PATH" + + // The default install paths are used to find Azure CLI. This is for security, so that any path in the calling program's Path environment is not used to execute Azure CLI. + azureCLIDefaultPathWindows := fmt.Sprintf("%s\\Microsoft SDKs\\Azure\\CLI2\\wbin; %s\\Microsoft SDKs\\Azure\\CLI2\\wbin", os.Getenv("ProgramFiles(x86)"), os.Getenv("ProgramFiles")) + + // Default path for non-Windows. + const azureCLIDefaultPath = "/bin:/sbin:/usr/bin:/usr/local/bin" + + // Validate resource, since it gets sent as a command line argument to Azure CLI + const invalidResourceErrorTemplate = "Resource %s is not in expected format. Only alphanumeric characters, [dot], [colon], [hyphen], and [forward slash] are allowed." + match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource) + if err != nil { + return nil, err + } + if !match { + return nil, fmt.Errorf(invalidResourceErrorTemplate, resource) + } + + ctx, cancel := context.WithTimeout(ctx, timeoutCLIRequest) + defer cancel() + + // Execute Azure CLI to get token + var cliCmd *exec.Cmd + if runtime.GOOS == "windows" { + cliCmd = exec.CommandContext(ctx, fmt.Sprintf("%s\\system32\\cmd.exe", os.Getenv("windir"))) + cliCmd.Env = os.Environ() + cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s;%s", os.Getenv(azureCLIPath), azureCLIDefaultPathWindows)) + cliCmd.Args = append(cliCmd.Args, "/c", "az") + } else { + cliCmd = exec.CommandContext(ctx, "az") + cliCmd.Env = os.Environ() + cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s:%s", os.Getenv(azureCLIPath), azureCLIDefaultPath)) + } + cliCmd.Args = append(cliCmd.Args, "account", "get-access-token", "-o", "json", "--resource", resource) + + var stderr bytes.Buffer + cliCmd.Stderr = &stderr + + output, err := cliCmd.Output() + if err != nil { + return nil, &CredentialUnavailableError{CredentialType: "Azure CLI Credential", Message: stderr.String()} + } + + return output, nil + } +} + +func (c *AzureCLICredential) createAccessToken(tk []byte) (*azcore.AccessToken, error) { + t := struct { + AccessToken string `json:"accessToken"` + Authority string `json:"_authority"` + ClientID string `json:"_clientId"` + ExpiresOn string `json:"expiresOn"` + IdentityProvider string `json:"identityProvider"` + IsMRRT bool `json:"isMRRT"` + RefreshToken string `json:"refreshToken"` + Resource string `json:"resource"` + TokenType string `json:"tokenType"` + UserID string `json:"userId"` + }{} + err := json.Unmarshal(tk, &t) + if err != nil { + return nil, err + } + + tokenExpirationDate, err := parseExpirationDate(t.ExpiresOn) + if err != nil { + return nil, fmt.Errorf("Error parsing Token Expiration Date %q: %+v", t.ExpiresOn, err) + } + + converted := &azcore.AccessToken{ + Token: t.AccessToken, + ExpiresOn: *tokenExpirationDate, + } + return converted, nil +} + +// parseExpirationDate parses either a Azure CLI or CloudShell date into a time object +func parseExpirationDate(input string) (*time.Time, error) { + // CloudShell (and potentially the Azure CLI in future) + expirationDate, cloudShellErr := time.Parse(time.RFC3339, input) + if cloudShellErr != nil { + // Azure CLI (Python) e.g. 2017-08-31 19:48:57.998857 (plus the local timezone) + const cliFormat = "2006-01-02 15:04:05.999999" + expirationDate, cliErr := time.ParseInLocation(cliFormat, input, time.Local) + if cliErr != nil { + return nil, fmt.Errorf("Error parsing expiration date %q.\n\nCloudShell Error: \n%+v\n\nCLI Error:\n%+v", input, cloudShellErr, cliErr) + } + return &expirationDate, nil + } + return &expirationDate, nil +} diff --git a/sdk/azidentity/azure_cli_credential_test.go b/sdk/azidentity/azure_cli_credential_test.go new file mode 100644 index 000000000000..f99c993afbfb --- /dev/null +++ b/sdk/azidentity/azure_cli_credential_test.go @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +var ( + mockCLITokenProviderSuccess = func(ctx context.Context, resource string) ([]byte, error) { + return []byte(" {\"accessToken\":\"mocktoken\" , " + + "\"expiresOn\": \"2007-01-01 01:01:01.079627\"," + + "\"subscription\": \"mocksub\"," + + "\"tenant\": \"mocktenant\"," + + "\"tokenType\": \"mocktype\"}"), nil + } + mockCLITokenProviderFailure = func(ctx context.Context, resource string) ([]byte, error) { + return nil, errors.New("provider failure message") + } +) + +func TestAzureCLICredential_GetTokenSuccess(t *testing.T) { + cred, err := NewAzureCLICredential(&AzureCLICredentialOptions{TokenProvider: mockCLITokenProviderSuccess}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + at, err := cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Expected an empty error but received: %v", err) + } + if len(at.Token) == 0 { + t.Fatalf(("Did not receive a token")) + } + if at.Token != "mocktoken" { + t.Fatalf(("Did not receive the correct access token")) + } +} + +func TestAzureCLICredential_GetTokenInvalidToken(t *testing.T) { + cred, err := NewAzureCLICredential(&AzureCLICredentialOptions{TokenProvider: mockCLITokenProviderFailure}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err == nil { + t.Fatalf("Expected an error but did not receive one.") + } +} + +func TestBearerPolicy_AzureCLICredential(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + cred, err := NewAzureCLICredential(&AzureCLICredentialOptions{TokenProvider: mockCLITokenProviderSuccess}) + if err != nil { + t.Fatalf("Did not expect an error but received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + _, err = pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + if err != nil { + t.Fatal("Expected nil error but received one") + } +} diff --git a/sdk/azidentity/bearer_token_policy.go b/sdk/azidentity/bearer_token_policy.go new file mode 100644 index 000000000000..e00d85d51524 --- /dev/null +++ b/sdk/azidentity/bearer_token_policy.go @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + bearerTokenPrefix = "Bearer " +) + +type bearerTokenPolicy struct { + // cond is used to synchronize token refresh. the locker + // must be locked when updating the following shared state. + cond *sync.Cond + + // renewing indicates that the token is in the process of being refreshed + renewing bool + + // header contains the authorization header value + header string + + // expiresOn is when the token will expire + expiresOn time.Time + + // the following fields are read-only + creds azcore.TokenCredential + options azcore.TokenRequestOptions +} + +func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationPolicyOptions) *bearerTokenPolicy { + return &bearerTokenPolicy{ + cond: sync.NewCond(&sync.Mutex{}), + creds: creds, + options: opts.Options, + } +} + +func (b *bearerTokenPolicy) Do(ctx context.Context, req *azcore.Request) (*azcore.Response, error) { + if req.URL.Scheme != "https" { + // HTTPS must be used, otherwise the tokens are at the risk of being exposed + return nil, &AuthenticationFailedError{msg: "token credentials require a URL using the HTTPS protocol scheme"} + } + // create a "refresh window" before the token's real expiration date. + // this allows callers to continue to use the old token while the + // refresh is in progress. + const window = 2 * time.Minute + now, getToken, header := time.Now(), false, "" + // acquire exclusive lock + b.cond.L.Lock() + for { + if b.expiresOn.IsZero() || b.expiresOn.Before(now) { + // token was never obtained or has expired + if !b.renewing { + // another go routine isn't refreshing the token so this one will + b.renewing = true + getToken = true + break + } + // getting here means this go routine will wait for the token to refresh + } else if b.expiresOn.Add(-window).Before(now) { + // token is within the expiration window + if !b.renewing { + // another go routine isn't refreshing the token so this one will + b.renewing = true + getToken = true + break + } + // this go routine will use the existing token while another refreshes it + header = b.header + break + } else { + // token is not expiring yet so use it as-is + header = b.header + break + } + // wait for the token to refresh + b.cond.Wait() + } + b.cond.L.Unlock() + if getToken { + // this go routine has been elected to refresh the token + tk, err := b.creds.GetToken(ctx, b.options) + if err != nil { + return nil, err + } + header = bearerTokenPrefix + tk.Token + // update shared state + b.cond.L.Lock() + b.renewing = false + b.header = header + b.expiresOn = tk.ExpiresOn + // signal any waiters that the token has been refreshed + b.cond.Broadcast() + b.cond.L.Unlock() + } + req.Request.Header.Set(azcore.HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) + req.Request.Header.Set(azcore.HeaderAuthorization, header) + return req.Next(ctx) +} diff --git a/sdk/azidentity/bearer_token_policy_test.go b/sdk/azidentity/bearer_token_policy_test.go new file mode 100644 index 000000000000..cbe167e62c38 --- /dev/null +++ b/sdk/azidentity/bearer_token_policy_test.go @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + accessTokenRespError = `{"error": "invalid_client","error_description": "Invalid client secret is provided.","error_codes": [0],"timestamp": "2019-12-01 19:00:00Z","trace_id": "2d091b0","correlation_id": "a999","error_uri": "https://login.contoso.com/error?code=0"}` + accessTokenRespSuccess = `{"access_token": "` + tokenValue + `", "expires_in": 3600}` + accessTokenRespMalformed = `{"access_token": 0, "expires_in": 3600}` + accessTokenRespShortLived = `{"access_token": "` + tokenValue + `", "expires_in": 0}` +) + +func TestBearerPolicy_SuccessGetToken(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, secret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + resp, err := pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + if err != nil { + t.Fatalf("Expected nil error but received one") + } + const expectedToken = bearerTokenPrefix + tokenValue + if token := resp.Request.Header.Get(azcore.HeaderAuthorization); token != expectedToken { + t.Fatalf("expected token '%s', got '%s'", expectedToken, token) + } +} + +func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusUnauthorized)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, wrongSecret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + resp, err := pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + var afe *AuthenticationFailedError + if !errors.As(err, &afe) { + t.Fatalf("unexpected error type %v", err) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestBearerTokenPolicy_TokenExpired(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespShortLived))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespShortLived))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespShortLived))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, secret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + req := azcore.NewRequest(http.MethodGet, srv.URL()) + _, err = pipeline.Do(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + _, err = pipeline.Do(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + +// with https scheme enabled we get an auth failed error which let's us test the is not retriable error +func TestRetryPolicy_IsNotRetriable(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusUnauthorized)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, wrongSecret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + _, err = pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + var afe *AuthenticationFailedError + if !errors.As(err, &afe) { + t.Fatalf("unexpected error type %v", err) + } +} + +func TestRetryPolicy_HTTPRequest(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusUnauthorized)) + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, wrongSecret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + _, err = pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + var afe *AuthenticationFailedError + if !errors.As(err, &afe) { + t.Fatalf("unexpected error type %v", err) + } +} diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go new file mode 100644 index 000000000000..b8bb15b00711 --- /dev/null +++ b/sdk/azidentity/chained_token_credential.go @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// ChainedTokenCredential provides a TokenCredential implementation that chains multiple TokenCredential sources to be tried in order +// and returns the token from the first successful call to GetToken(). +type ChainedTokenCredential struct { + sources []azcore.TokenCredential +} + +// NewChainedTokenCredential creates an instance of ChainedTokenCredential with the specified TokenCredential sources. +func NewChainedTokenCredential(sources ...azcore.TokenCredential) (*ChainedTokenCredential, error) { + if len(sources) == 0 { + credErr := &CredentialUnavailableError{CredentialType: "Chained Token Credential", Message: "Length of sources cannot be 0"} + azcore.Log().Write(azcore.LogError, logCredentialError(credErr.CredentialType, credErr)) + return nil, credErr + } + for _, source := range sources { + if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil + credErr := &CredentialUnavailableError{CredentialType: "Chained Token Credential", Message: "Sources cannot contain a nil TokenCredential"} + azcore.Log().Write(azcore.LogError, logCredentialError(credErr.CredentialType, credErr)) + return nil, credErr + } + } + return &ChainedTokenCredential{sources: sources}, nil +} + +// GetToken sequentially calls TokenCredential.GetToken on all the specified sources, returning the token from the first successful call to GetToken(). +func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts azcore.TokenRequestOptions) (token *azcore.AccessToken, err error) { + var errList []*CredentialUnavailableError + for _, cred := range c.sources { // loop through all of the credentials provided in sources + token, err = cred.GetToken(ctx, opts) // make a GetToken request for the current credential in the loop + var credErr *CredentialUnavailableError + if errors.As(err, &credErr) { // check if we received a CredentialUnavailableError + errList = append(errList, credErr) // if we did receive a CredentialUnavailableError then we append it to our error slice and continue looping for a good credential + } else if err != nil { // if we receive some other type of error then we must stop looping and process the error accordingly + var authenticationFailed *AuthenticationFailedError + if errors.As(err, &authenticationFailed) { // if the error is an AuthenticationFailedError we return the error related to the invalid credential and append all of the other error messages received prior to this point + authErr := &AuthenticationFailedError{msg: "Received an AuthenticationFailedError, there is an invalid credential in the chain. " + createChainedErrorMessage(errList), inner: err} + addGetTokenFailureLogs("Chained Token Credential", authErr) + return nil, authErr + } + addGetTokenFailureLogs("Chained Token Credential", err) + return nil, err // if we receive some other error type this is unexpected and we simple return the unexpected error + } else { + azcore.Log().Write(LogCredential, logGetTokenSuccess(c, opts)) + return token, nil // if we did not receive an error then we return the token + } + } + // if we reach this point it means that all of the credentials in the chain returned CredentialUnavailableErrors + credErr := &CredentialUnavailableError{CredentialType: "Chained Token Credential", Message: createChainedErrorMessage(errList)} + addGetTokenFailureLogs("Chained Token Credential", credErr) + return nil, credErr +} + +// AuthenticationPolicy implements the azcore.Credential interface on ChainedTokenCredential and sets the bearer token +func (c *ChainedTokenCredential) AuthenticationPolicy(options azcore.AuthenticationPolicyOptions) azcore.Policy { + return newBearerTokenPolicy(c, options) +} + +// helper function used to chain the error messages of the CredentialUnavailableError slice +func createChainedErrorMessage(errList []*CredentialUnavailableError) string { + msg := "" + for _, err := range errList { + msg += err.Error() + } + + return msg +} diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go new file mode 100644 index 000000000000..96fd678bcacc --- /dev/null +++ b/sdk/azidentity/chained_token_credential_test.go @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +func TestChainedTokenCredential_InstantiateSuccess(t *testing.T) { + err := initEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Could not set environment variables for testing: %v", err) + } + secCred, err := NewClientSecretCredential(tenantID, clientID, secret, nil) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + envCred, err := NewEnvironmentCredential(nil) + if err != nil { + t.Fatalf("Could not find appropriate environment credentials") + } + cred, err := NewChainedTokenCredential(secCred, envCred) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cred != nil { + if len(cred.sources) != 2 { + t.Fatalf("Expected 2 sources in the chained token credential, instead found %d", len(cred.sources)) + } + } +} + +func TestChainedTokenCredential_InstantiateFailure(t *testing.T) { + secCred, err := NewClientSecretCredential(tenantID, clientID, secret, nil) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = NewChainedTokenCredential(secCred, nil) + if err == nil { + t.Fatalf("Expected an error for sending a nil credential in the chain") + } + var credErr *CredentialUnavailableError + if !errors.As(err, &credErr) { + t.Fatalf("Expected a CredentialUnavailableError, but received: %T", credErr) + } + _, err = NewChainedTokenCredential() + if err == nil { + t.Fatalf("Expected an error for not sending any credential sources") + } + if !errors.As(err, &credErr) { + t.Fatalf("Expected a CredentialUnavailableError, but received: %T", credErr) + } +} + +func TestChainedTokenCredential_GetTokenSuccess(t *testing.T) { + err := initEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Could not set environment variables for testing: %v", err) + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + envCred, err := NewEnvironmentCredential(&TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Failed to create environment credential: %v", err) + } + cred, err := NewChainedTokenCredential(secCred, envCred) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tk, err := cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Received an error when attempting to get a token but expected none") + } + if tk.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if tk.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } +} + +func TestChainedTokenCredential_GetTokenFail(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusUnauthorized)) + testURL := srv.URL() + secCred, err := NewClientSecretCredential(tenantID, clientID, wrongSecret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &testURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + cred, err := NewChainedTokenCredential(secCred) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err == nil { + t.Fatalf("Expected an error but did not receive one") + } + var authErr *AuthenticationFailedError + if !errors.As(err, &authErr) { + t.Fatalf("Expected Error Type: AuthenticationFailedError, ReceivedErrorType: %T", err) + } + if len(err.Error()) == 0 { + t.Fatalf("Did not create an appropriate error message") + } +} + +func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendError(&CredentialUnavailableError{CredentialType: "MockCredential", Message: "Mocking a credential unavailable error"}) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + testURL := srv.URL() + secCred, err := NewClientSecretCredential(tenantID, clientID, wrongSecret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &testURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + // The chain has the same credential twice, since it doesn't matter what credential we add to the chain as long as it is not a nil credential. + // Most credentials will not be instantiated if the conditions do not exist to allow them to be used, thus returning a + // CredentialUnavailable error from the constructor. In order to test the CredentialUnavailable functionality for + // ChainedTokenCredential we have to mock with two valid credentials, but the first will fail since the first response queued + // in the test server is a CredentialUnavailable error. + cred, err := NewChainedTokenCredential(secCred, secCred) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tk, err := cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Received an error when attempting to get a token but expected none") + } + if tk.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if tk.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } +} + +func TestBearerPolicy_ChainedTokenCredential(t *testing.T) { + err := initEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to initialize environment variables. Received: %v", err) + } + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, secret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + chainedCred, err := NewChainedTokenCredential(cred) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + chainedCred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + _, err = pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + if err != nil { + t.Fatalf("Expected an empty error but receive: %v", err) + } +} diff --git a/sdk/azidentity/client_certificate_credential.go b/sdk/azidentity/client_certificate_credential.go new file mode 100644 index 000000000000..2626b09f7331 --- /dev/null +++ b/sdk/azidentity/client_certificate_credential.go @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// ClientCertificateCredential enables authentication of a service principal to Azure Active Directory using a certificate that is assigned to its App Registration. More information +// on how to configure certificate authentication can be found here: +// https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-certificate-credentials#register-your-certificate-with-azure-ad +type ClientCertificateCredential struct { + client *aadIdentityClient + tenantID string // The Azure Active Directory tenant (directory) ID of the service principal + clientID string // The client (application) ID of the service principal + clientCertificate string // Path to the client certificate generated for the App Registration used to authenticate the client +} + +// NewClientCertificateCredential creates an instance of ClientCertificateCredential with the details needed to authenticate against Azure Active Directory with the specified certificate. +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal. +// clientID: The client (application) ID of the service principal. +// clientCertificate: The path to the client certificate that was generated for the App Registration used to authenticate the client. +// options: configure the management of the requests sent to Azure Active Directory. +func NewClientCertificateCredential(tenantID string, clientID string, clientCertificate string, options *TokenCredentialOptions) (*ClientCertificateCredential, error) { + _, err := os.Stat(clientCertificate) + if err != nil { + credErr := &CredentialUnavailableError{CredentialType: "Client Certificate Credential", Message: "Certificate file not found in path: " + clientCertificate} + azcore.Log().Write(azcore.LogError, logCredentialError(credErr.CredentialType, credErr)) + return nil, credErr + } + c, err := newAADIdentityClient(options) + if err != nil { + return nil, err + } + return &ClientCertificateCredential{tenantID: tenantID, clientID: clientID, clientCertificate: clientCertificate, client: c}, nil +} + +// GetToken obtains a token from Azure Active Directory, using the certificate in the file path. +// scopes: The list of scopes for which the token will have access. +// ctx: controlling the request lifetime. +// Returns an AccessToken which can be used to authenticate service client calls. +func (c *ClientCertificateCredential) GetToken(ctx context.Context, opts azcore.TokenRequestOptions) (*azcore.AccessToken, error) { + tk, err := c.client.authenticateCertificate(ctx, c.tenantID, c.clientID, c.clientCertificate, opts.Scopes) + if err != nil { + addGetTokenFailureLogs("Client Certificate Credential", err) + return nil, err + } + azcore.Log().Write(LogCredential, logGetTokenSuccess(c, opts)) + return tk, nil +} + +// AuthenticationPolicy implements the azcore.Credential interface on ClientSecretCredential. +func (c *ClientCertificateCredential) AuthenticationPolicy(options azcore.AuthenticationPolicyOptions) azcore.Policy { + return newBearerTokenPolicy(c, options) +} diff --git a/sdk/azidentity/client_certificate_credential_test.go b/sdk/azidentity/client_certificate_credential_test.go new file mode 100644 index 000000000000..fe7f60d47f87 --- /dev/null +++ b/sdk/azidentity/client_certificate_credential_test.go @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "net/url" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + certificatePath = "testdata/certificate.pem" + wrongCertificatePath = "wrong_certificate_path.pem" +) + +func TestClientCertificateCredential_CreateAuthRequestSuccess(t *testing.T) { + cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, nil) + if err != nil { + t.Fatalf("Failed to instantiate credential") + } + req, err := cred.client.createClientCertificateAuthRequest(cred.tenantID, cred.clientID, cred.clientCertificate, []string{scope}) + if err != nil { + t.Fatalf("Unexpectedly received an error: %v", err) + } + if req.Request.Header.Get(azcore.HeaderContentType) != azcore.HeaderURLEncoded { + t.Fatalf("Unexpected value for Content-Type header") + } + body, err := ioutil.ReadAll(req.Request.Body) + if err != nil { + t.Fatalf("Unable to read request body") + } + bodyStr := string(body) + reqQueryParams, err := url.ParseQuery(bodyStr) + if err != nil { + t.Fatalf("Unable to parse query params in request") + } + if reqQueryParams[qpClientID][0] != clientID { + t.Fatalf("Unexpected client ID in the client_id header") + } + if reqQueryParams[qpGrantType][0] != "client_credentials" { + t.Fatalf("Wrong grant type in request body") + } + if reqQueryParams[qpClientAssertionType][0] != clientAssertionType { + t.Fatalf("Wrong client assertion type assigned to request") + } + if reqQueryParams[qpScope][0] != scope { + t.Fatalf("Unexpected scope in scope header") + } + if len(reqQueryParams[qpClientAssertion][0]) == 0 { + t.Fatalf("Client assertion is not present on the request") + } + if req.Request.URL.Host != defaultTestAuthorityHost { + t.Fatalf("Unexpected default authority host") + } + if req.Request.URL.Scheme != "https" { + t.Fatalf("Wrong request scheme") + } +} + +func TestClientCertificateCredential_GetTokenSuccess(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } +} + +func TestClientCertificateCredential_GetTokenInvalidCredentials(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusUnauthorized)) + srvURL := srv.URL() + cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Did not expect an error but received one: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err == nil { + t.Fatalf("Expected to receive a nil error, but received: %v", err) + } + var authFailed *AuthenticationFailedError + if !errors.As(err, &authFailed) { + t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err) + } +} + +func TestClientCertificateCredential_WrongCertificatePath(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusUnauthorized)) + srvURL := srv.URL() + _, err := NewClientCertificateCredential(tenantID, clientID, wrongCertificatePath, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err == nil { + t.Fatalf("Expected an error but did not receive one") + } +} + +func TestClientCertificateCredential_GetTokenCheckPrivateKeyBlocks(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + cred, err := NewClientCertificateCredential(tenantID, clientID, "testdata/certificate_formatB.pem", &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } +} + +func TestClientCertificateCredential_GetTokenCheckCertificateBlocks(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + cred, err := NewClientCertificateCredential(tenantID, clientID, "testdata/certificate_formatA.pem", &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } +} + +func TestClientCertificateCredential_GetTokenEmptyCertificate(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + cred, err := NewClientCertificateCredential(tenantID, clientID, "testdata/certificate_empty.pem", &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err == nil { + t.Fatalf("Expected an error but received nil") + } +} + +func TestClientCertificateCredential_GetTokenNoPrivateKey(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + cred, err := NewClientCertificateCredential(tenantID, clientID, "testdata/certificate_nokey.pem", &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err == nil { + t.Fatalf("Expected an error but received nil") + } +} + +func TestBearerPolicy_ClientCertificateCredential(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srvURL := srv.URL() + cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Did not expect an error but received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + _, err = pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + if err != nil { + t.Fatalf("Expected nil error but received one") + } +} diff --git a/sdk/azidentity/client_secret_credential.go b/sdk/azidentity/client_secret_credential.go new file mode 100644 index 000000000000..eb0e39aa94e5 --- /dev/null +++ b/sdk/azidentity/client_secret_credential.go @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// ClientSecretCredential enables authentication to Azure Active Directory using a client secret that was generated for an App Registration. More information on how +// to configure a client secret can be found here: +// https://docs.microsoft.com/en-us/azure/active-directory/develop/quickstart-configure-app-access-web-apis#add-credentials-to-your-web-application +type ClientSecretCredential struct { + client *aadIdentityClient + tenantID string // Gets the Azure Active Directory tenant (directory) ID of the service principal + clientID string // Gets the client (application) ID of the service principal + clientSecret string // Gets the client secret that was generated for the App Registration used to authenticate the client. +} + +// NewClientSecretCredential constructs a new ClientSecretCredential with the details needed to authenticate against Azure Active Directory with a client secret. +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal. +// clientID: The client (application) ID of the service principal. +// clientSecret: A client secret that was generated for the App Registration used to authenticate the client. +// options: allow to configure the management of the requests sent to Azure Active Directory. +func NewClientSecretCredential(tenantID string, clientID string, clientSecret string, options *TokenCredentialOptions) (*ClientSecretCredential, error) { + c, err := newAADIdentityClient(options) + if err != nil { + return nil, err + } + return &ClientSecretCredential{tenantID: tenantID, clientID: clientID, clientSecret: clientSecret, client: c}, nil +} + +// GetToken obtains a token from Azure Active Directory, using the specified client secret to authenticate. +// ctx: Context used to control the request lifetime. +// opts: TokenRequestOptions contains the list of scopes for which the token will have access. +// Returns an AccessToken which can be used to authenticate service client calls. +func (c *ClientSecretCredential) GetToken(ctx context.Context, opts azcore.TokenRequestOptions) (*azcore.AccessToken, error) { + tk, err := c.client.authenticate(ctx, c.tenantID, c.clientID, c.clientSecret, opts.Scopes) + if err != nil { + addGetTokenFailureLogs("Client Secret Credential", err) + return nil, err + } + azcore.Log().Write(LogCredential, logGetTokenSuccess(c, opts)) + return tk, nil +} + +// AuthenticationPolicy implements the azcore.Credential interface on ClientSecretCredential and calls the Bearer Token policy +// to get the bearer token. +func (c *ClientSecretCredential) AuthenticationPolicy(options azcore.AuthenticationPolicyOptions) azcore.Policy { + return newBearerTokenPolicy(c, options) +} + +var _ azcore.TokenCredential = (*ClientSecretCredential)(nil) diff --git a/sdk/azidentity/client_secret_credential_test.go b/sdk/azidentity/client_secret_credential_test.go new file mode 100644 index 000000000000..007c9b695e61 --- /dev/null +++ b/sdk/azidentity/client_secret_credential_test.go @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "net/url" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + tenantID = "expected_tenant" + clientID = "expected_client" + secret = "secret" + wrongSecret = "wrong_secret" + tokenValue = "new_token" + scope = "http://storage.azure.com/.default" + defaultTestAuthorityHost = "login.microsoftonline.com" +) + +func TestClientSecretCredential_CreateAuthRequestSuccess(t *testing.T) { + cred, err := NewClientSecretCredential(tenantID, clientID, secret, nil) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + req, err := cred.client.createClientSecretAuthRequest(cred.tenantID, cred.clientID, cred.clientSecret, []string{scope}) + if err != nil { + t.Fatalf("Unexpectedly received an error: %v", err) + } + if req.Request.Header.Get(azcore.HeaderContentType) != azcore.HeaderURLEncoded { + t.Fatalf("Unexpected value for Content-Type header") + } + body, err := ioutil.ReadAll(req.Request.Body) + if err != nil { + t.Fatalf("Unable to read request body") + } + bodyStr := string(body) + reqQueryParams, err := url.ParseQuery(bodyStr) + if err != nil { + t.Fatalf("Unable to parse query params in request") + } + if reqQueryParams[qpClientID][0] != clientID { + t.Fatalf("Unexpected client ID in the client_id header") + } + if reqQueryParams[qpClientSecret][0] != secret { + t.Fatalf("Unexpected secret in the client_secret header") + } + if reqQueryParams[qpScope][0] != scope { + t.Fatalf("Unexpected scope in scope header") + } + if req.Request.URL.Host != defaultTestAuthorityHost { + t.Fatalf("Unexpected default authority host") + } + if req.Request.URL.Scheme != "https" { + t.Fatalf("Wrong request scheme") + } +} + +func TestClientSecretCredential_GetTokenSuccess(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, secret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Expected an empty error but received: %v", err) + } +} + +func TestClientSecretCredential_GetTokenInvalidCredentials(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.SetResponse(mock.WithBody([]byte(accessTokenRespError)), mock.WithStatusCode(http.StatusUnauthorized)) + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, wrongSecret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err == nil { + t.Fatalf("Expected an error but did not receive one.") + } + var authFailed *AuthenticationFailedError + if !errors.As(err, &authFailed) { + t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err) + } else { + var respError *AADAuthenticationFailedError + if !errors.As(authFailed.Unwrap(), &respError) { + t.Fatalf("Expected: AADAuthenticationFailedError, Received: %T", err) + } else { + if len(respError.Message) == 0 { + t.Fatalf("Did not receive an error message") + } + if len(respError.Description) == 0 { + t.Fatalf("Did not receive an error description") + } + if len(respError.Timestamp) == 0 { + t.Fatalf("Did not receive a timestamp") + } + if len(respError.TraceID) == 0 { + t.Fatalf("Did not receive a TraceID") + } + if len(respError.CorrelationID) == 0 { + t.Fatalf("Did not receive a CorrelationID") + } + if len(respError.URI) == 0 { + t.Fatalf("Did not receive an error URI") + } + if respError.Response == nil { + t.Fatalf("Did not receive an error response") + } + } + } +} + +func TestClientSecretCredential_GetTokenUnexpectedJSON(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespMalformed))) + srvURL := srv.URL() + cred, err := NewClientSecretCredential(tenantID, clientID, secret, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Failed to create the credential") + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err == nil { + t.Fatalf("Expected a JSON marshal error but received nil") + } +} diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go new file mode 100644 index 000000000000..1bcd0c0d4e01 --- /dev/null +++ b/sdk/azidentity/default_azure_credential.go @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + developerSignOnClientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" +) + +// DefaultAzureCredentialOptions contains options for configuring how credentials are acquired. +type DefaultAzureCredentialOptions struct { + // set this field to true in order to exclude the EnvironmentCredential from the set of + // credentials that will be used to authenticate with + ExcludeEnvironmentCredential bool + // set this field to true in order to exclude the ManagedIdentityCredential from the set of + // credentials that will be used to authenticate with + ExcludeMSICredential bool +} + +// NewDefaultAzureCredential provides a default ChainedTokenCredential configuration for applications that will be deployed to Azure. The following credential +// types will be tried, in the following order: +// - EnvironmentCredential +// - ManagedIdentityCredential +// Consult the documentation for these credential types for more information on how they attempt authentication. +func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*ChainedTokenCredential, error) { + var creds []azcore.TokenCredential + errMsg := "" + + if options == nil { + options = &DefaultAzureCredentialOptions{} + } + + if !options.ExcludeEnvironmentCredential { + envCred, err := NewEnvironmentCredential(nil) + if err == nil { + creds = append(creds, envCred) + } else { + errMsg += err.Error() + } + } + + if !options.ExcludeMSICredential { + msiCred, err := NewManagedIdentityCredential("", nil) + if err == nil { + creds = append(creds, msiCred) + } else { + errMsg += err.Error() + } + } + // if no credentials are added to the slice of TokenCredentials then return a CredentialUnavailableError + if len(creds) == 0 { + err := &CredentialUnavailableError{CredentialType: "Default Azure Credential", Message: errMsg} + azcore.Log().Write(azcore.LogError, logCredentialError(err.CredentialType, err)) + return nil, err + } + azcore.Log().Write(LogCredential, "Azure Identity => NewDefaultAzureCredential() invoking NewChainedTokenCredential()") + return NewChainedTokenCredential(creds...) +} diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go new file mode 100644 index 000000000000..37299019f594 --- /dev/null +++ b/sdk/azidentity/default_azure_credential_test.go @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "os" + "testing" +) + +func TestDefaultAzureCredential_ExcludeEnvCredential(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + _ = os.Setenv("MSI_ENDPOINT", "http://localhost:3000") + cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ExcludeEnvironmentCredential: true}) + if err != nil { + t.Fatalf("Did not expect to receive an error in creating the credential") + } + + if len(cred.sources) != 1 { + t.Fatalf("Length of ChainedTokenCredential sources for DefaultAzureCredential. Expected: 1, Received: %d", len(cred.sources)) + } + _ = os.Setenv("MSI_ENDPOINT", "") + +} + +func TestDefaultAzureCredential_ExcludeMSICredential(t *testing.T) { + err := initEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ExcludeMSICredential: true}) + if err != nil { + t.Fatalf("Did not expect to receive an error in creating the credential") + } + if len(cred.sources) != 1 { + t.Fatalf("Length of ChainedTokenCredential sources for DefaultAzureCredential. Expected: 1, Received: %d", len(cred.sources)) + } + +} + +func TestDefaultAzureCredential_ExcludeAllCredentials(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + var credUnavailable *CredentialUnavailableError + _, err = NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ExcludeEnvironmentCredential: false, ExcludeMSICredential: true}) + if err == nil { + t.Fatalf("Expected an error but received nil") + } + if !errors.As(err, &credUnavailable) { + t.Fatalf("Expected: CredentialUnavailableError, Received: %T", err) + } + +} + +func TestDefaultAzureCredential_NilOptions(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + err = initEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + cred, err := NewDefaultAzureCredential(nil) + if err != nil { + t.Fatalf("Did not expect to receive an error in creating the credential") + } + c := newManagedIdentityClient(nil) + // if the test is running in a MSI environment then the length of sources would be two since it will include environmnet credential and managed identity credential + if msiType, err := c.getMSIType(context.Background()); msiType == msiTypeIMDS || msiType == msiTypeCloudShell || msiType == msiTypeAppService { + if len(cred.sources) != 2 { + t.Fatalf("Length of ChainedTokenCredential sources for DefaultAzureCredential. Expected: 2, Received: %d", len(cred.sources)) + } + //if a credential unavailable error is received or msiType is unknown then only the environment credential will be added + } else if unavailableErr := (*CredentialUnavailableError)(nil); errors.As(err, &unavailableErr) || msiType == msiTypeUnknown { + if len(cred.sources) != 1 { + t.Fatalf("Length of ChainedTokenCredential sources for DefaultAzureCredential. Expected: 1, Received: %d", len(cred.sources)) + } + // if there is some other unexpected error then we fail here + } else if err != nil { + t.Fatalf("Received an error when trying to determine MSI type: %v", err) + } +} diff --git a/sdk/azidentity/device_code_credential.go b/sdk/azidentity/device_code_credential.go new file mode 100644 index 000000000000..43c6767823d7 --- /dev/null +++ b/sdk/azidentity/device_code_credential.go @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code" +) + +// DeviceCodeCredential authenticates a user using the device code flow, and provides access tokens for that user account. +// For more information on the device code authentication flow see: https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code. +type DeviceCodeCredential struct { + client *aadIdentityClient + tenantID string // Gets the Azure Active Directory tenant (directory) ID of the service principal + clientID string // Gets the client (application) ID of the service principal + callback func(string) // Sends the user a message with a verification URL and device code to sign in to the login server + refreshToken string // Gets the refresh token sent from the service and will be used to retreive new access tokens after the initial request for a token. Thread safety for updates is handled in the AuthenticationPolicy since only one goroutine will be updating at a time +} + +// NewDeviceCodeCredential constructs a new DeviceCodeCredential used to authenticate against Azure Active Directory with a device code. +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal. If none is set then the default value ("organizations") will be used in place of the tenantID. +// clientID: The client (application) ID of the service principal. +// callback: The callback function used to send the login message back to the user +// options: Options used to configure the management of the requests sent to Azure Active Directory. +func NewDeviceCodeCredential(tenantID string, clientID string, callback func(string), options *TokenCredentialOptions) (*DeviceCodeCredential, error) { + c, err := newAADIdentityClient(options) + if err != nil { + return nil, err + } + return &DeviceCodeCredential{tenantID: tenantID, clientID: clientID, callback: callback, client: c}, nil +} + +// GetToken obtains a token from Azure Active Directory, following the device code authentication +// flow. This function first requests a device code and requests that the user login before continuing to authenticate the device. +// This function will keep polling the service for a token until the user logs in. +// scopes: The list of scopes for which the token will have access. The "offline_access" scope is checked for and automatically added in case it isn't present to allow for silent token refresh. +// ctx: The context for controlling the request lifetime. +// Returns an AccessToken which can be used to authenticate service client calls. +func (c *DeviceCodeCredential) GetToken(ctx context.Context, opts azcore.TokenRequestOptions) (*azcore.AccessToken, error) { + for i, scope := range opts.Scopes { + if scope == "offline_access" { // if we find that the opts.Scopes slice contains "offline_access" then we don't need to do anything and exit + break + } + if i == len(opts.Scopes)-1 && scope != "offline_access" { // if we haven't found "offline_access" when reaching the last element in the slice then we append it + opts.Scopes = append(opts.Scopes, "offline_access") + } + } + if len(c.refreshToken) != 0 { + tk, err := c.client.refreshAccessToken(ctx, c.tenantID, c.clientID, "", c.refreshToken, opts.Scopes) + if err != nil { + addGetTokenFailureLogs("Device Code Credential", err) + return nil, err + } + // assign new refresh token to the credential for future use + c.refreshToken = tk.refreshToken + azcore.Log().Write(LogCredential, logGetTokenSuccess(c, opts)) + // passing the access token and/or error back up + return tk.token, nil + } + // if there is no refreshToken, then begin the Device Code flow from the beginning + // make initial request to the device code endpoint for a device code and instructions for authentication + dc, err := c.client.requestNewDeviceCode(ctx, c.tenantID, c.clientID, opts.Scopes) + if err != nil { + addGetTokenFailureLogs("Device Code Credential", err) + return nil, err // TODO check what error type to return here + } + // send authentication flow instructions back to the user to log in and authorize the device + c.callback(dc.Message) + // poll the token endpoint until a valid access token is received or until authentication fails + for { + tk, err := c.client.authenticateDeviceCode(ctx, c.tenantID, c.clientID, dc.DeviceCode, opts.Scopes) + // if there is no error, save the refresh token and return the token credential + if err == nil { + c.refreshToken = tk.refreshToken + azcore.Log().Write(LogCredential, logGetTokenSuccess(c, opts)) + return tk.token, err + } + // if there is an error, check for an AADAuthenticationFailedError in order to check the status for token retrieval + // if the error is not an AADAuthenticationFailedError, then fail here since something unexpected occurred + if authRespErr := (*AADAuthenticationFailedError)(nil); errors.As(err, &authRespErr) && authRespErr.Message == "authorization_pending" { + // wait for the interval specified from the initial device code endpoint and then poll for the token again + time.Sleep(time.Duration(dc.Interval) * time.Second) + } else { + addGetTokenFailureLogs("Device Code Credential", err) + // any other error should be returned + return nil, err + } + } +} + +// AuthenticationPolicy implements the azcore.Credential interface on ClientSecretCredential. +func (c *DeviceCodeCredential) AuthenticationPolicy(options azcore.AuthenticationPolicyOptions) azcore.Policy { + return newBearerTokenPolicy(c, options) +} + +// deviceCodeResult is used to store device code related information to help the user login and allow the device code flow to continue +// to request a token to authenticate a user +type deviceCodeResult struct { + UserCode string `json:"user_code"` // User code returned by the service + DeviceCode string `json:"device_code"` // Device code returned by the service + VerificationURL string `json:"verification_uri"` // Verification URL where the user must navigate to authenticate using the device code and credentials. + Interval int64 `json:"interval"` // Polling interval time to check for completion of authentication flow. + Message string `json:"message"` // User friendly text response that can be used for display purpose. +} diff --git a/sdk/azidentity/device_code_credential_test.go b/sdk/azidentity/device_code_credential_test.go new file mode 100644 index 000000000000..4e131e772452 --- /dev/null +++ b/sdk/azidentity/device_code_credential_test.go @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "net/url" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + deviceCode = "device_code" + deviceCodeResponse = `{"user_code":"test_code","device_code":"test_device_code","verification_uri":"https://microsoft.com/devicelogin","expires_in":900,"interval":5,"message":"To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code test_code to authenticate."}` + deviceCodeScopes = "user.read offline_access openid profile email" + authorizationPendingResponse = `{"error": "authorization_pending","error_description": "Authorization pending.","error_codes": [],"timestamp": "2019-12-01 19:00:00Z","trace_id": "2d091b0","correlation_id": "a999","error_uri": "https://login.contoso.com/error?code=0"}` + expiredTokenResponse = `{"error": "expired_token","error_description": "Token has expired.","error_codes": [],"timestamp": "2019-12-01 19:00:00Z","trace_id": "2d091b0","correlation_id": "a999","error_uri": "https://login.contoso.com/error?code=0"}` +) + +func TestDeviceCodeCredential_CreateAuthRequestSuccess(t *testing.T) { + handler := func(s string) {} + cred, err := NewDeviceCodeCredential(tenantID, clientID, handler, nil) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + req, err := cred.client.createDeviceCodeAuthRequest(cred.tenantID, cred.clientID, deviceCode, []string{deviceCodeScopes}) + if err != nil { + t.Fatalf("Unexpectedly received an error: %v", err) + } + if req.Request.Header.Get(azcore.HeaderContentType) != azcore.HeaderURLEncoded { + t.Fatalf("Unexpected value for Content-Type header") + } + body, err := ioutil.ReadAll(req.Request.Body) + if err != nil { + t.Fatalf("Unable to read request body") + } + bodyStr := string(body) + reqQueryParams, err := url.ParseQuery(bodyStr) + if err != nil { + t.Fatalf("Unable to parse query params in request") + } + if reqQueryParams[qpGrantType][0] != deviceCodeGrantType { + t.Fatalf("Unexpected grant type") + } + if reqQueryParams[qpClientID][0] != clientID { + t.Fatalf("Unexpected client ID in the client_id header") + } + if reqQueryParams[qpDeviceCode][0] != deviceCode { + t.Fatalf("Unexpected username in the username header") + } + if reqQueryParams[qpScope][0] != deviceCodeScopes { + t.Fatalf("Unexpected scope in scope header") + } + if req.Request.URL.Host != defaultTestAuthorityHost { + t.Fatalf("Unexpected default authority host") + } + if req.Request.URL.Scheme != "https" { + t.Fatalf("Wrong request scheme") + } +} + +func TestDeviceCodeCredential_CreateAuthRequestEmptyTenant(t *testing.T) { + handler := func(s string) {} + cred, err := NewDeviceCodeCredential("", clientID, handler, nil) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + req, err := cred.client.createDeviceCodeAuthRequest(cred.tenantID, cred.clientID, deviceCode, []string{deviceCodeScopes}) + if err != nil { + t.Fatalf("Unexpectedly received an error: %v", err) + } + if req.Request.Header.Get(azcore.HeaderContentType) != azcore.HeaderURLEncoded { + t.Fatalf("Unexpected value for Content-Type header") + } + body, err := ioutil.ReadAll(req.Request.Body) + if err != nil { + t.Fatalf("Unable to read request body") + } + bodyStr := string(body) + reqQueryParams, err := url.ParseQuery(bodyStr) + if err != nil { + t.Fatalf("Unable to parse query params in request") + } + if reqQueryParams[qpGrantType][0] != deviceCodeGrantType { + t.Fatalf("Unexpected grant type") + } + if reqQueryParams[qpClientID][0] != clientID { + t.Fatalf("Unexpected client ID in the client_id header") + } + if reqQueryParams[qpDeviceCode][0] != deviceCode { + t.Fatalf("Unexpected username in the username header") + } + if reqQueryParams[qpScope][0] != deviceCodeScopes { + t.Fatalf("Unexpected scope in scope header") + } + if req.Request.URL.Host != defaultTestAuthorityHost { + t.Fatalf("Unexpected default authority host") + } + if req.Request.URL.Scheme != "https" { + t.Fatalf("Wrong request scheme") + } + if req.Request.URL.Path != "/organizations/oauth2/v2.0/token" { + t.Fatalf("Did not set the right path when passing in an empty tenant ID") + } +} + +func TestDeviceCodeCredential_RequestNewDeviceCodeEmptyTenant(t *testing.T) { + handler := func(s string) {} + cred, err := NewDeviceCodeCredential("", clientID, handler, nil) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + req, err := cred.client.createDeviceCodeNumberRequest(cred.tenantID, cred.clientID, []string{deviceCodeScopes}) + if err != nil { + t.Fatalf("Unexpectedly received an error: %v", err) + } + if req.Request.Header.Get(azcore.HeaderContentType) != azcore.HeaderURLEncoded { + t.Fatalf("Unexpected value for Content-Type header") + } + body, err := ioutil.ReadAll(req.Request.Body) + if err != nil { + t.Fatalf("Unable to read request body") + } + bodyStr := string(body) + reqQueryParams, err := url.ParseQuery(bodyStr) + if err != nil { + t.Fatalf("Unable to parse query params in request") + } + if reqQueryParams[qpClientID][0] != clientID { + t.Fatalf("Unexpected client ID in the client_id header") + } + if reqQueryParams[qpScope][0] != deviceCodeScopes { + t.Fatalf("Unexpected scope in scope header") + } + if req.Request.URL.Host != defaultTestAuthorityHost { + t.Fatalf("Unexpected default authority host") + } + if req.Request.URL.Scheme != "https" { + t.Fatalf("Wrong request scheme") + } + if req.Request.URL.Path != "/organizations/oauth2/v2.0/devicecode" { + t.Fatalf("Did not set the right path when passing in an empty tenant ID") + } +} + +func TestDeviceCodeCredential_GetTokenSuccess(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(deviceCodeResponse))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srvURL := srv.URL() + handler := func(string) {} + cred, err := NewDeviceCodeCredential(tenantID, clientID, handler, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + tk, err := cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{deviceCodeScopes}}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } + if tk.Token != "new_token" { + t.Fatalf("Received an unexpected value in azcore.AccessToken.Token") + } +} + +func TestDeviceCodeCredential_GetTokenInvalidCredentials(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusUnauthorized)) + srvURL := srv.URL() + handler := func(string) {} + cred, err := NewDeviceCodeCredential(tenantID, clientID, handler, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{deviceCodeScopes}}) + if err == nil { + t.Fatalf("Expected an error but did not receive one.") + } +} + +func TestDeviceCodeCredential_GetTokenAuthorizationPending(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(deviceCodeResponse))) + srv.AppendResponse(mock.WithBody([]byte(authorizationPendingResponse)), mock.WithStatusCode(http.StatusUnauthorized)) + srv.AppendResponse(mock.WithBody([]byte(authorizationPendingResponse)), mock.WithStatusCode(http.StatusUnauthorized)) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + handler := func(string) {} + cred, err := NewDeviceCodeCredential(tenantID, clientID, handler, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{deviceCodeScopes}}) + if err != nil { + t.Fatalf("Expected an empty error but received %v", err) + } +} + +func TestDeviceCodeCredential_GetTokenExpiredToken(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(deviceCodeResponse))) + srv.AppendResponse(mock.WithBody([]byte(authorizationPendingResponse)), mock.WithStatusCode(http.StatusUnauthorized)) + srv.AppendResponse(mock.WithBody([]byte(expiredTokenResponse)), mock.WithStatusCode(http.StatusUnauthorized)) + srvURL := srv.URL() + handler := func(string) {} + cred, err := NewDeviceCodeCredential(tenantID, clientID, handler, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{deviceCodeScopes}}) + if err == nil { + t.Fatalf("Expected an error but received none") + } +} + +func TestDeviceCodeCredential_GetTokenWithRefreshTokenFailure(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespError)), mock.WithStatusCode(http.StatusUnauthorized)) + srvURL := srv.URL() + handler := func(string) {} + cred, err := NewDeviceCodeCredential(tenantID, clientID, handler, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + cred.refreshToken = "refresh_token" + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{deviceCodeScopes}}) + if err == nil { + t.Fatalf("Expected an error but did not receive one") + } + var aadErr *AADAuthenticationFailedError + if !errors.As(err, &aadErr) { + t.Fatalf("Did not receive an AADAuthenticationFailedError but was expecting one") + } +} + +func TestDeviceCodeCredential_GetTokenWithRefreshTokenSuccess(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + handler := func(string) {} + cred, err := NewDeviceCodeCredential(tenantID, clientID, handler, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + cred.refreshToken = "refresh_token" + tk, err := cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{deviceCodeScopes}}) + if err != nil { + t.Fatalf("Received an unexpected error: %s", err.Error()) + } + if tk.Token != "new_token" { + t.Fatalf("Unexpected value for token") + } +} + +func TestBearerPolicy_DeviceCodeCredential(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(deviceCodeResponse))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srvURL := srv.URL() + handler := func(string) {} + cred, err := NewDeviceCodeCredential(tenantID, clientID, handler, &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{deviceCodeScopes}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + req := azcore.NewRequest(http.MethodGet, srv.URL()) + _, err = pipeline.Do(context.Background(), req) + if err != nil { + t.Fatalf("Expected an empty error but receive: %v", err) + } +} diff --git a/sdk/azidentity/environment_credential.go b/sdk/azidentity/environment_credential.go new file mode 100644 index 000000000000..936e840792e3 --- /dev/null +++ b/sdk/azidentity/environment_credential.go @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// NewEnvironmentCredential creates an instance of the ClientSecretCredential type and reads credential details from environment variables. +// If the expected environment variables are not found at this time, then a CredentialUnavailableError will be returned. +// options: The options used to configure the management of the requests sent to Azure Active Directory. +func NewEnvironmentCredential(options *TokenCredentialOptions) (*ClientSecretCredential, error) { + tenantID := os.Getenv("AZURE_TENANT_ID") + if tenantID == "" { + err := &CredentialUnavailableError{CredentialType: "Environment Credential", Message: "Missing environment variable AZURE_TENANT_ID"} + azcore.Log().Write(azcore.LogError, logCredentialError(err.CredentialType, err)) + return nil, err + } + + clientID := os.Getenv("AZURE_CLIENT_ID") + if clientID == "" { + err := &CredentialUnavailableError{CredentialType: "Environment Credential", Message: "Missing environment variable AZURE_CLIENT_ID"} + azcore.Log().Write(azcore.LogError, logCredentialError(err.CredentialType, err)) + return nil, err + } + + clientSecret := os.Getenv("AZURE_CLIENT_SECRET") + if clientSecret == "" { + err := &CredentialUnavailableError{CredentialType: "Environment Credential", Message: "Missing environment variable AZURE_CLIENT_SECRET"} + azcore.Log().Write(azcore.LogError, logCredentialError(err.CredentialType, err)) + return nil, err + } + azcore.Log().Write(LogCredential, "Azure Identity => NewEnvironmentCredential() invoking ClientSecretCredential") + return NewClientSecretCredential(tenantID, clientID, clientSecret, options) +} diff --git a/sdk/azidentity/environment_credential_test.go b/sdk/azidentity/environment_credential_test.go new file mode 100644 index 000000000000..53885baedf60 --- /dev/null +++ b/sdk/azidentity/environment_credential_test.go @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "errors" + "os" + "testing" +) + +func initEnvironmentVarsForTest() error { + err := os.Setenv("AZURE_TENANT_ID", tenantID) + if err != nil { + return err + } + err = os.Setenv("AZURE_CLIENT_ID", clientID) + if err != nil { + return err + } + err = os.Setenv("AZURE_CLIENT_SECRET", secret) + if err != nil { + return err + } + return nil +} + +func resetEnvironmentVarsForTest() error { + err := os.Setenv("AZURE_TENANT_ID", "") + if err != nil { + return err + } + err = os.Setenv("AZURE_CLIENT_ID", "") + if err != nil { + return err + } + err = os.Setenv("AZURE_CLIENT_SECRET", "") + if err != nil { + return err + } + return nil +} + +func TestEnvironmentCredential_TenantIDNotSet(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + err = os.Setenv("AZURE_CLIENT_ID", clientID) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + err = os.Setenv("AZURE_CLIENT_SECRET", secret) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + _, err = NewEnvironmentCredential(nil) + if err == nil { + t.Fatalf("Expected an error but received nil") + } + var credentialUnavailable *CredentialUnavailableError + if !errors.As(err, &credentialUnavailable) { + t.Fatalf("Expected a credential unavailable error, instead received: %T", err) + } +} + +func TestEnvironmentCredential_ClientIDNotSet(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + err = os.Setenv("AZURE_TENANT_ID", tenantID) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + err = os.Setenv("AZURE_CLIENT_SECRET", secret) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + _, err = NewEnvironmentCredential(nil) + if err == nil { + t.Fatalf("Expected an error but received nil") + } + var credentialUnavailable *CredentialUnavailableError + if !errors.As(err, &credentialUnavailable) { + t.Fatalf("Expected a credential unavailable error, instead received: %T", err) + } +} + +func TestEnvironmentCredential_ClientSecretNotSet(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + err = os.Setenv("AZURE_TENANT_ID", tenantID) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + err = os.Setenv("AZURE_CLIENT_ID", clientID) + if err != nil { + t.Fatalf("Unexpected error when initializing environment variables: %v", err) + } + _, err = NewEnvironmentCredential(nil) + if err == nil { + t.Fatalf("Expected an error but received nil") + } + var credentialUnavailable *CredentialUnavailableError + if !errors.As(err, &credentialUnavailable) { + t.Fatalf("Expected a credential unavailable error, instead received: %T", err) + } +} diff --git a/sdk/azidentity/fingerprint.go b/sdk/azidentity/fingerprint.go new file mode 100644 index 000000000000..b047293078d6 --- /dev/null +++ b/sdk/azidentity/fingerprint.go @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "bufio" + "bytes" + "crypto/sha1" + "encoding/pem" + "errors" + "fmt" + "os" +) + +// fingerprint type wraps a byte slice that contains the corresponding SHA-1 fingerprint for the client's certificate +type fingerprint []byte + +// String represents the fingerprint digest as a series of +// colon-delimited hexadecimal octets. +func (f fingerprint) String() string { + var buf bytes.Buffer + for i, b := range f { + if i > 0 { + fmt.Fprintf(&buf, ":") + } + fmt.Fprintf(&buf, "%02x", b) + } + return buf.String() +} + +// spkiFingerprint calculates the fingerprint of the certificate based on it's Subject Public Key Info with the SHA-1 +// signing algorithm. +func spkiFingerprint(cert string) (fingerprint, error) { + privateKeyFile, err := os.Open(cert) + if err != nil { + return nil, fmt.Errorf("%s: %w", cert, err) + } + defer privateKeyFile.Close() + + pemFileInfo, err := privateKeyFile.Stat() + if err != nil { + return nil, err + } + + var size int64 = pemFileInfo.Size() + pemBytes := make([]byte, size) + buffer := bufio.NewReader(privateKeyFile) + _, err = buffer.Read(pemBytes) + if err != nil { + return nil, err + } + // Get first block of PEM file + data, rest := pem.Decode([]byte(pemBytes)) + const certificateBlock = "CERTIFICATE" + if data.Type != certificateBlock { + for len(rest) > 0 { + data, rest = pem.Decode(rest) + if data.Type == certificateBlock { + // Sign the CERTIFICATE block with SHA1 + h := sha1.New() + _, err := h.Write(data.Bytes) + if err != nil { + return nil, err + } + + return fingerprint(h.Sum(nil)), nil + } + } + return nil, errors.New("Cannot find CERTIFICATE in file") + } + h := sha1.New() + _, err = h.Write(data.Bytes) + if err != nil { + return nil, err + } + + return fingerprint(h.Sum(nil)), nil +} diff --git a/sdk/azidentity/go.mod b/sdk/azidentity/go.mod new file mode 100644 index 000000000000..ec84ba11f6ce --- /dev/null +++ b/sdk/azidentity/go.mod @@ -0,0 +1,8 @@ +module github.com/Azure/azure-sdk-for-go/sdk/azidentity + +go 1.13 + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.0 + github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.0 +) diff --git a/sdk/azidentity/go.sum b/sdk/azidentity/go.sum new file mode 100644 index 000000000000..00dfdfebed95 --- /dev/null +++ b/sdk/azidentity/go.sum @@ -0,0 +1,4 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.0 h1:VdhfbVpQ3dkhXYOx/Wj1+utikcZkZSZSmpqmXWwaNJY= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.0/go.mod h1:hL9TGc07RkJVzDIBxsYXC/r0M+YiRkvl4z1elXCD+8s= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.0 h1:cLpVMIkXC/umSP9DMz9I6FttDWJAsmvhpaB6MlkagGY= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.0/go.mod h1:Q+TCQnSr+clUU0JU+xrHZ3slYCxw17AOFdvWFpQXjAY= diff --git a/sdk/azidentity/jwt.go b/sdk/azidentity/jwt.go new file mode 100644 index 000000000000..125ee31e5d24 --- /dev/null +++ b/sdk/azidentity/jwt.go @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" +) + +// headerJWT type contains the fields necessary to create a JSON Web Token including the x5t field which must contain a x.509 certificate thumbprint +type headerJWT struct { + Typ string `json:"typ"` + Alg string `json:"alg"` + X5t string `json:"x5t"` +} + +// payloadJWT type contains all fields that are necessary when creating a JSON Web Token payload section +type payloadJWT struct { + JTI string `json:"jti"` + AUD string `json:"aud"` + ISS string `json:"iss"` + SUB string `json:"sub"` + NBF int64 `json:"nbf"` + EXP int64 `json:"exp"` +} + +// createClientAssertionJWT build the JWT header, payload and signature, +// then returns a string for the JWT assertion +func createClientAssertionJWT(clientID string, audience string, clientCertificate string) (string, error) { + fingerprint, err := spkiFingerprint(clientCertificate) + if err != nil { + return "", err + } + + headerData := headerJWT{ + Typ: "JWT", + Alg: "RS256", + X5t: base64.RawURLEncoding.EncodeToString(fingerprint), + } + + headerJSON, err := json.Marshal(headerData) + if err != nil { + return "", fmt.Errorf("Marshal headerJWT: %w", err) + } + header := base64.RawURLEncoding.EncodeToString(headerJSON) + + payloadData := payloadJWT{ + JTI: uuid.New().String(), + AUD: audience, + ISS: clientID, + SUB: clientID, + NBF: time.Now().Unix(), + EXP: time.Now().Add(30 * time.Minute).Unix(), + } + + payloadJSON, err := json.Marshal(payloadData) + if err != nil { + return "", fmt.Errorf("Marshal payloadJWT: %w", err) + } + payload := base64.RawURLEncoding.EncodeToString(payloadJSON) + result := header + "." + payload + hashed := []byte(result) + hashedSum := sha256.Sum256(hashed) + cryptoRand := rand.Reader + + privateKey, err := getPrivateKey(clientCertificate) + if err != nil { + return "", err + } + + signed, err := rsa.SignPKCS1v15(cryptoRand, privateKey, crypto.SHA256, hashedSum[:]) + if err != nil { + return "", err + } + + signature := base64.RawURLEncoding.EncodeToString(signed) + + return result + "." + signature, nil +} diff --git a/sdk/azidentity/logging.go b/sdk/azidentity/logging.go new file mode 100644 index 000000000000..4e6ca943c000 --- /dev/null +++ b/sdk/azidentity/logging.go @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "fmt" + "os" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// LogCredential is the log classification that can be used for logging Azure Identity related information +const LogCredential azcore.LogClassification = "credential" + +// log environment variables that can be used for credential types +func logEnvVars() { + if !azcore.Log().Should(LogCredential) { + return + } + // Log available environment variables + envVars := []string{} + if envCheck := os.Getenv("AZURE_TENANT_ID"); len(envCheck) > 0 { + envVars = append(envVars, "AZURE_TENANT_ID") + } + if envCheck := os.Getenv("AZURE_CLIENT_ID"); len(envCheck) > 0 { + envVars = append(envVars, "AZURE_CLIENT_ID") + } + if envCheck := os.Getenv("AZURE_CLIENT_SECRET"); len(envCheck) > 0 { + envVars = append(envVars, "AZURE_CLIENT_SECRET") + } + if envCheck := os.Getenv("AZURE_AUTHORITY_HOST"); len(envCheck) > 0 { + envVars = append(envVars, "AZURE_AUTHORITY_HOST") + } + if envCheck := os.Getenv("AZURE_CLI_PATH"); len(envCheck) > 0 { + envVars = append(envVars, "AZURE_CLI_PATH") + } + if len(envVars) > 0 { + azcore.Log().Write(LogCredential, fmt.Sprintf("Azure Identity => Found the following environment variables: %s", strings.Join(envVars, ", "))) + } +} + +func logGetTokenSuccess(cred azcore.TokenCredential, opts azcore.TokenRequestOptions) string { + msg := fmt.Sprintf("Azure Identity => GetToken() result for %T: SUCCESS\n", cred) + msg += fmt.Sprintf("Azure Identity => Scopes: [%s]", strings.Join(opts.Scopes, ", ")) + return msg +} + +func logGetTokenFailure(credName string) string { + return fmt.Sprintf("Azure Identity => ERROR in GetToken() call for %s. Please check the log for the error.", credName) +} + +func logCredentialError(credName string, err error) string { + return fmt.Sprintf("Azure Identity => ERROR in %s: %s", credName, err.Error()) +} + +func logMSIEnv(msi msiType) string { + switch msi { + case 1: + return "Azure Identity => Managed Identity environment: IMDS" + case 2: + return "Azure Identity => Managed Identity environment: MSI_ENDPOINT" + case 3: + return "Azure Identity => Managed Identity environment: MSI_ENDPOINT" + case 4: + return "Azure Identity => Managed Identity environment: Unavailable" + default: + return "Azure Identity => Managed Identity environment: Unknown" + } +} + +func addGetTokenFailureLogs(credName string, err error) { + azcore.Log().Write(azcore.LogError, logCredentialError(credName, err)) + azcore.Log().Write(azcore.LogError, logGetTokenFailure(credName)) +} diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go new file mode 100644 index 000000000000..7fc1dbf19dff --- /dev/null +++ b/sdk/azidentity/managed_identity_client.go @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +const ( + imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" +) + +const ( + msiEndpointEnvironemntVariable = "MSI_ENDPOINT" + msiSecretEnvironemntVariable = "MSI_SECRET" + appServiceMsiAPIVersion = "2017-09-01" + imdsAPIVersion = "2018-02-01" +) + +type msiType int + +const ( + msiTypeUnknown msiType = 0 + msiTypeIMDS msiType = 1 + msiTypeAppService msiType = 2 + msiTypeCloudShell msiType = 3 + msiTypeUnavailable msiType = 4 +) + +// managedIdentityClient provides the base for authenticating in managed identity environments +// This type includes an azcore.Pipeline and TokenCredentialOptions. +type managedIdentityClient struct { + pipeline azcore.Pipeline + imdsAPIVersion string + imdsAvailableTimeoutMS time.Duration + msiType msiType + endpoint *url.URL +} + +type wrappedNumber json.Number + +func (n *wrappedNumber) UnmarshalJSON(b []byte) error { + c := string(b) + if c == "\"\"" { + return nil + } + return json.Unmarshal(b, (*json.Number)(n)) +} + +var ( + imdsURL *url.URL // these are initialized in the init func and are R/O afterwards + defaultMSIOpts = newDefaultManagedIdentityOptions() +) + +func init() { + // The error is checked for in managed_identity_client_test.go and should not be ignored if the test fails + imdsURL, _ = url.Parse(imdsEndpoint) +} + +func newDefaultManagedIdentityOptions() *ManagedIdentityCredentialOptions { + return &ManagedIdentityCredentialOptions{ + HTTPClient: azcore.DefaultHTTPClientTransport(), + } +} + +// newManagedIdentityClient creates a new instance of the ManagedIdentityClient with the ManagedIdentityCredentialOptions +// that are passed into it along with a default pipeline. +// options: ManagedIdentityCredentialOptions configure policies for the pipeline and the authority host that +// will be used to retrieve tokens and authenticate +func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) *managedIdentityClient { + logEnvVars() + options = options.setDefaultValues() + return &managedIdentityClient{ + pipeline: newDefaultMSIPipeline(*options), // a pipeline that includes the specific requirements for MSI authentication, such as custom retry policy options + imdsAPIVersion: imdsAPIVersion, // this field will be set to whatever value exists in the constant and is used when creating requests to IMDS + imdsAvailableTimeoutMS: 500, // we allow a timeout of 500 ms since the endpoint might be slow to respond + msiType: msiTypeUnknown, // when creating a new managedIdentityClient, the current MSI type is unknown and will be tested for and replaced once authenticate() is called from GetToken on the credential side + } +} + +// authenticate creates an authentication request for a Managed Identity and returns the resulting Access Token if successful. +// ctx: The current context for controlling the request lifetime. +// clientID: The client (application) ID of the service principal. +// scopes: The scopes required for the token. +func (c *managedIdentityClient) authenticate(ctx context.Context, clientID string, scopes []string) (*azcore.AccessToken, error) { + currentMSI, err := c.getMSIType(ctx) + if err != nil { + return nil, err + } + // This condition should never be true since getMSIType returns an error in these cases + // if MSI is unavailable or we were unable to determine the type return a nil access token + if currentMSI == msiTypeUnavailable || currentMSI == msiTypeUnknown { + return nil, &CredentialUnavailableError{CredentialType: "Managed Identity Credential", Message: "Please make sure you are running in a managed identity environment, such as a VM, Azure Functions, Cloud Shell, etc..."} + } + + AT, err := c.sendAuthRequest(ctx, currentMSI, clientID, scopes) + if err != nil { + return nil, err + } + return AT, nil +} + +func (c *managedIdentityClient) sendAuthRequest(ctx context.Context, msiType msiType, clientID string, scopes []string) (*azcore.AccessToken, error) { + msg, err := c.createAuthRequest(msiType, clientID, scopes) + if err != nil { + return nil, err + } + + resp, err := c.pipeline.Do(ctx, msg) + if err != nil { + return nil, err + } + + if resp.HasStatusCode(successStatusCodes[:]...) { + return c.createAccessToken(resp) + } + + return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)} +} + +func (c *managedIdentityClient) createAccessToken(res *azcore.Response) (*azcore.AccessToken, error) { + value := struct { + // these are the only fields that we use + Token string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn wrappedNumber `json:"expires_in,omitempty"` // this field should always return the number of seconds for which a token is valid + ExpiresOn string `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string + }{} + if err := res.UnmarshalAsJSON(&value); err != nil { + return nil, fmt.Errorf("internal AccessToken: %w", err) + } + if value.ExpiresIn != "" { + expiresIn, err := json.Number(value.ExpiresIn).Int64() + if err != nil { + return nil, err + } + return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Now().Add(time.Second * time.Duration(expiresIn)).UTC()}, nil + } + if expiresOn, err := strconv.Atoi(value.ExpiresOn); err == nil { + return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Now().Add(time.Second * time.Duration(expiresOn)).UTC()}, nil + } + // this is the case when expires_on is a time string + // this is the format of the string coming from the service + if expiresOn, err := time.Parse("01/02/2006 15:04:05 PM +00:00", value.ExpiresOn); err == nil { // the date string specified in the layout param of time.Parse cannot be changed, Golang expects whatever layout to always signify January 2, 2006 at 3:04 PM + eo := expiresOn.UTC() + return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil + } else { + return nil, err + } +} + +func (c *managedIdentityClient) createAuthRequest(msiType msiType, clientID string, scopes []string) (*azcore.Request, error) { + switch msiType { + case msiTypeIMDS: + return c.createIMDSAuthRequest(scopes), nil + case msiTypeAppService: + return c.createAppServiceAuthRequest(clientID, scopes), nil + case msiTypeCloudShell: + return c.createCloudShellAuthRequest(clientID, scopes) + default: + errorMsg := "" + switch msiType { + case msiTypeUnavailable: + errorMsg = "unavailable" + default: + errorMsg = "unknown" + } + + return nil, &CredentialUnavailableError{CredentialType: "Managed Identity Credential", Message: "Make sure you are running in a valid Managed Identity Environment. Status: " + errorMsg} + } +} + +func (c *managedIdentityClient) createIMDSAuthRequest(scopes []string) *azcore.Request { + request := azcore.NewRequest(http.MethodGet, *c.endpoint) + request.Header.Set(azcore.HeaderMetadata, "true") + q := request.URL.Query() + q.Add("api-version", c.imdsAPIVersion) + q.Add("resource", strings.Join(scopes, " ")) + request.URL.RawQuery = q.Encode() + + return request +} + +func (c *managedIdentityClient) createAppServiceAuthRequest(clientID string, scopes []string) *azcore.Request { + request := azcore.NewRequest(http.MethodGet, *c.endpoint) + request.Header.Set("secret", os.Getenv(msiSecretEnvironemntVariable)) + q := request.URL.Query() + q.Add("api-version", appServiceMsiAPIVersion) + q.Add("resource", strings.Join(scopes, " ")) + if clientID != "" { + q.Add(qpClientID, clientID) + } + request.URL.RawQuery = q.Encode() + + return request +} + +func (c *managedIdentityClient) createCloudShellAuthRequest(clientID string, scopes []string) (*azcore.Request, error) { + request := azcore.NewRequest(http.MethodPost, *c.endpoint) + request.Header.Set(azcore.HeaderContentType, azcore.HeaderURLEncoded) + request.Header.Set(azcore.HeaderMetadata, "true") + data := url.Values{} + data.Set("resource", strings.Join(scopes, " ")) + if clientID != "" { + data.Set("client_id", clientID) + } + dataEncoded := data.Encode() + body := azcore.NopCloser(strings.NewReader(dataEncoded)) + err := request.SetBody(body) + if err != nil { + return nil, err + } + return request, nil +} + +func (c *managedIdentityClient) getMSIType(ctx context.Context) (msiType, error) { + if c.msiType == msiTypeUnknown { // if we haven't already determined the msi type + if endpointEnvVar := os.Getenv(msiEndpointEnvironemntVariable); endpointEnvVar != "" { // if the env var MSI_ENDPOINT is set + endpoint, err := url.Parse(endpointEnvVar) + if err != nil { + return msiTypeUnknown, err + } + c.endpoint = endpoint + if secretEnvVar := os.Getenv(msiSecretEnvironemntVariable); secretEnvVar != "" { // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the MsiType is AppService + c.msiType = msiTypeAppService + } else { // if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell + c.msiType = msiTypeCloudShell + } + } else if c.imdsAvailable(ctx) { // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the MsiType is Imds. This will timeout after 500 milliseconds + c.endpoint = imdsURL + c.msiType = msiTypeIMDS + } else { // if MSI_ENDPOINT is NOT set and IMDS enpoint is not available ManagedIdentity is not available + c.msiType = msiTypeUnavailable + return c.msiType, &CredentialUnavailableError{CredentialType: "Managed Identity Credential", Message: "Make sure you are running in a valid Managed Identity Environment"} + } + } + return c.msiType, nil +} + +func (c *managedIdentityClient) imdsAvailable(ctx context.Context) bool { + tempCtx, cancel := context.WithTimeout(ctx, c.imdsAvailableTimeoutMS*time.Millisecond) + defer cancel() + request := azcore.NewRequest(http.MethodGet, *imdsURL) + q := request.URL.Query() + q.Add("api-version", c.imdsAPIVersion) + request.URL.RawQuery = q.Encode() + _, err := c.pipeline.Do(tempCtx, request) + return err == nil +} diff --git a/sdk/azidentity/managed_identity_client_test.go b/sdk/azidentity/managed_identity_client_test.go new file mode 100644 index 000000000000..f38461c495ad --- /dev/null +++ b/sdk/azidentity/managed_identity_client_test.go @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "net/url" + "testing" +) + +func TestIMDSEndpointParse(t *testing.T) { + _, err := url.Parse(imdsEndpoint) + if err != nil { + t.Fatalf("Failed to parse the IMDS endpoint: %v", err) + } +} + +// func TestNewDefaultMSIPipeline(t *testing.T) { +// p := newDefaultMSIPipeline(ManagedIdentityCredentialOptions{}) +// } diff --git a/sdk/azidentity/managed_identity_credential.go b/sdk/azidentity/managed_identity_credential.go new file mode 100644 index 000000000000..a20c277332fa --- /dev/null +++ b/sdk/azidentity/managed_identity_credential.go @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "os" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// ManagedIdentityCredentialOptions contains parameters that can be used to configure the pipeline used with Managed Identity Credential. +type ManagedIdentityCredentialOptions struct { + // HTTPClient sets the transport for making HTTP requests. + // Leave this as nil to use the default HTTP transport. + HTTPClient azcore.Transport + + // LogOptions configures the built-in request logging policy behavior. + LogOptions azcore.RequestLogOptions + + // Telemetry configures the built-in telemetry policy behavior. + Telemetry azcore.TelemetryOptions +} + +func (m *ManagedIdentityCredentialOptions) setDefaultValues() *ManagedIdentityCredentialOptions { + if m == nil { + m = defaultMSIOpts + } + return m +} + +// ManagedIdentityCredential attempts authentication using a managed identity that has been assigned to the deployment environment. This authentication type works in several +// managed identity environments such as Azure VMs, App Service, Azure Functions, Azure CloudShell, among others. More information about configuring managed identities can be found here: +// https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview +type ManagedIdentityCredential struct { + clientID string + client *managedIdentityClient +} + +// NewManagedIdentityCredential creates an instance of the ManagedIdentityCredential capable of authenticating a resource that has a managed identity. +// clientID: The client ID to authenticate for a user assigned managed identity. +// options: ManagedIdentityCredentialOptions that configure the pipeline for requests sent to Azure Active Directory. +// More information on user assigned managed identities cam be found here: +// https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview#how-a-user-assigned-managed-identity-works-with-an-azure-vm +func NewManagedIdentityCredential(clientID string, options *ManagedIdentityCredentialOptions) (*ManagedIdentityCredential, error) { + // Create a new Managed Identity Client with default options + client := newManagedIdentityClient(options) + // Create a context that will timeout after 500 milliseconds (that is the amount of time designated to find out if the IMDS endpoint is available) + ctx, cancelFunc := context.WithTimeout(context.Background(), time.Duration(client.imdsAvailableTimeoutMS)*time.Millisecond) + defer cancelFunc() + msiType, err := client.getMSIType(ctx) + // If there is an error that means that the code is not running in a Managed Identity environment + if err != nil { + credErr := &CredentialUnavailableError{CredentialType: "Managed Identity Credential", Message: "Please make sure you are running in a managed identity environment, such as a VM, Azure Functions, Cloud Shell, etc..."} + azcore.Log().Write(azcore.LogError, logCredentialError(credErr.CredentialType, credErr)) + return nil, credErr + } + // Assign the msiType discovered onto the client + client.msiType = msiType + // check if no clientID is specified then check if it exists in an environment variable + if len(clientID) == 0 { + clientID = os.Getenv("AZURE_CLIENT_ID") + } + return &ManagedIdentityCredential{clientID: clientID, client: client}, nil +} + +// GetToken obtains an AccessToken from the Managed Identity service if available. +// scopes: The list of scopes for which the token will have access. +// Returns an AccessToken which can be used to authenticate service client calls. +func (c *ManagedIdentityCredential) GetToken(ctx context.Context, opts azcore.TokenRequestOptions) (*azcore.AccessToken, error) { + tk, err := c.client.authenticate(ctx, c.clientID, opts.Scopes) + if err != nil { + addGetTokenFailureLogs("Managed Identity Credential", err) + return nil, err + } + azcore.Log().Write(LogCredential, logGetTokenSuccess(c, opts)) + azcore.Log().Write(LogCredential, logMSIEnv(c.client.msiType)) + return tk, err +} + +// AuthenticationPolicy implements the azcore.Credential interface on ManagedIdentityCredential. +// Please note: the TokenRequestOptions included in AuthenticationPolicyOptions must be a slice of resources in this case and not scopes +func (c *ManagedIdentityCredential) AuthenticationPolicy(options azcore.AuthenticationPolicyOptions) azcore.Policy { + // The following code will remove the /.default suffix from any scopes passed into the method since ManagedIdentityCredentials expect a resource string instead of a scope string + for i := range options.Options.Scopes { + options.Options.Scopes[i] = strings.TrimSuffix(options.Options.Scopes[i], defaultSuffix) + } + return newBearerTokenPolicy(c, options) +} diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go new file mode 100644 index 000000000000..5177fb3475e1 --- /dev/null +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "net/http" + "net/url" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + msiScope = "https://storage.azure.com" + appServiceTokenSuccessResp = `{"access_token": "new_token", "expires_on": "09/14/2017 00:00:00 PM +00:00", "resource": "https://vault.azure.net", "token_type": "Bearer"}` + expiresOnIntResp = `{"access_token": "new_token", "refresh_token": "", "expires_in": "", "expires_on": "1560974028", "not_before": "1560970130", "resource": "https://vault.azure.net", "token_type": "Bearer"}` +) + +func TestManagedIdentityCredential_GetTokenInCloudShellLive(t *testing.T) { + if len(os.Getenv("MSI_ENDPOINT")) == 0 { + t.Skip() + } + msiCred, err := NewManagedIdentityCredential(clientID, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err != nil { + t.Fatalf("Received an error when attempting to retrieve a token") + } +} + +func TestManagedIdentityCredential_GetTokenInCloudShellMock(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + testURL := srv.URL() + _ = os.Setenv("MSI_ENDPOINT", testURL.String()) + msiCred, err := NewManagedIdentityCredential(clientID, &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err != nil { + t.Fatalf("Received an error when attempting to retrieve a token") + } +} + +func TestManagedIdentityCredential_GetTokenInCloudShellMockFail(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusUnauthorized)) + testURL := srv.URL() + _ = os.Setenv("MSI_ENDPOINT", testURL.String()) + msiCred, err := NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err == nil { + t.Fatalf("Expected an error but did not receive one") + } +} + +func TestManagedIdentityCredential_GetTokenInAppServiceMock(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(appServiceTokenSuccessResp))) + testURL := srv.URL() + _ = os.Setenv("MSI_ENDPOINT", testURL.String()) + _ = os.Setenv("MSI_SECRET", "secret") + msiCred, err := NewManagedIdentityCredential(clientID, &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err != nil { + t.Fatalf("Received an error when attempting to retrieve a token") + } +} + +func TestManagedIdentityCredential_CreateAccessTokenExpiresOnInt(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(expiresOnIntResp))) + testURL := srv.URL() + _ = os.Setenv("MSI_ENDPOINT", testURL.String()) + _ = os.Setenv("MSI_SECRET", "secret") + msiCred, err := NewManagedIdentityCredential(clientID, &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err != nil { + t.Fatalf("Received an error when attempting to retrieve a token") + } +} + +func TestManagedIdentityCredential_GetTokenInAppServiceMockFail(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusUnauthorized)) + testURL := srv.URL() + _ = os.Setenv("MSI_ENDPOINT", testURL.String()) + _ = os.Setenv("MSI_SECRET", "secret") + msiCred, err := NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err == nil { + t.Fatalf("Expected an error but did not receive one") + } +} + +// func TestManagedIdentityCredential_GetTokenIMDSMock(t *testing.T) { +// timeout := time.After(5 * time.Second) +// done := make(chan bool) +// go func() { +// err := resetEnvironmentVarsForTest() +// if err != nil { +// t.Fatalf("Unable to set environment variables") +// } +// srv, close := mock.NewServer() +// defer close() +// srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) +// msiCred := NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{HTTPClient: srv}) +// _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) +// if err == nil { +// t.Fatalf("Cannot run IMDS test in this environment") +// } +// time.Sleep(550 * time.Millisecond) +// done <- true +// }() + +// select { +// case <-timeout: +// t.Fatal("Test didn't finish in time") +// case <-done: +// } +// } + +func TestManagedIdentityCredential_NewManagedIdentityCredentialFail(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusUnauthorized)) + _ = os.Setenv("MSI_ENDPOINT", "https://t .com") + _, err = NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err == nil { + t.Fatalf("Expected an error but did not receive one") + } +} + +func TestBearerPolicy_ManagedIdentityCredential(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + testURL := srv.URL() + _ = os.Setenv("MSI_ENDPOINT", testURL.String()) + cred, err := NewManagedIdentityCredential(clientID, &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{msiScope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + _, err = pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + if err != nil { + t.Fatalf("Expected an empty error but receive: %v", err) + } +} + +func TestManagedIdentityCredential_GetTokenUnexpectedJSON(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespMalformed))) + testURL := srv.URL() + _ = os.Setenv("MSI_ENDPOINT", testURL.String()) + msiCred, err := NewManagedIdentityCredential(clientID, &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err == nil { + t.Fatalf("Expected a JSON marshal error but received nil") + } +} + +func TestManagedIdentityCredential_CreateIMDSAuthRequest(t *testing.T) { + cred, err := NewManagedIdentityCredential(clientID, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cred.client.endpoint = imdsURL + req := cred.client.createIMDSAuthRequest([]string{msiScope}) + if req.Request.Header.Get(azcore.HeaderMetadata) != "true" { + t.Fatalf("Unexpected value for Content-Type header") + } + reqQueryParams, err := url.ParseQuery(req.URL.RawQuery) + if err != nil { + t.Fatalf("Unable to parse IMDS query params: %v", err) + } + if reqQueryParams["api-version"][0] != imdsAPIVersion { + t.Fatalf("Unexpected IMDS API version") + } + if reqQueryParams["resource"][0] != msiScope { + t.Fatalf("Unexpected resource in resource query param") + } + if req.Request.URL.Host != imdsURL.Host { + t.Fatalf("Unexpected default authority host") + } + if req.Request.URL.Scheme != "http" { + t.Fatalf("Wrong request scheme") + } +} + +func TestManagedIdentityCredential_GetTokenEnvVar(t *testing.T) { + err := resetEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Unable to set environment variables") + } + err = os.Setenv("AZURE_CLIENT_ID", "test_client_id") + if err != nil { + t.Fatal(err) + } + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + testURL := srv.URL() + _ = os.Setenv("MSI_ENDPOINT", testURL.String()) + msiCred, err := NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{HTTPClient: srv}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + at, err := msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err != nil { + t.Fatalf("Received an error when attempting to retrieve a token") + } + if at.Token != "new_token" { + t.Fatalf("Did not receive the correct access token") + } +} diff --git a/sdk/azidentity/testdata/certificate.pem b/sdk/azidentity/testdata/certificate.pem new file mode 100644 index 000000000000..4b66bfa021a0 --- /dev/null +++ b/sdk/azidentity/testdata/certificate.pem @@ -0,0 +1,49 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDL1hG+JYCfIPp3 +tlZ05J4pYIJ3Ckfs432bE3rYuWlR2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRY +OCI69s4+lP3DwR8uBCp9xyVkF8thXfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+Qf +oxAb6tx0kEc7V3ozBLWoIDJjfwJ3NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIr +Aa7pxHzo/Nd0U3e7z+DlBcJV7dY6TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bC +lG0u7unS7QOBMd6bOGkeL+Bc+n22slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpX +wj/Ek0F7AgMBAAECggEAblU3UWdXUcs2CCqIbcl52wfEVs8X05/n01MeAcWKvqYG +hvGcz7eLvhir5dQoXcF3VhybMrIe6C4WcBIiZSxGwxU+rwEP8YaLwX1UPfOrQM7s +sZTdFTLWfUslO3p7q300fdRA92iG9COMDZvkElh0cBvQksxs9sSr149l9vk+ymtC +uBhZtHG6Ki0BIMBNC9jGUqDuOatXl/dkK4tNjXrNJT7tVwzPaqnNALIWl6B+k9oQ +m1oNhSH2rvs9tw2ITXfIoIk9KdOMjQVUD43wKOaz0hNZhUsb1OFuls7UtRzaFcZH +rMd/M8DtA104QTTlHK+XS7r+nqdv7+ZyB+suTdM+oQKBgQDxCrJZU3hJ0eJ4VYhK +xGDfVGNpYxNkQ4CDB9fwRNbFr/Ck3kgzfE9QxTx1pJOolVmfuFmk9B86in4UNy91 +KdaqT79AU5RdOBXNN6tuMbLC0AVqe8sZq+1vWVVwbCstffxEMmyW1Ju/FLYPl2Zp +e5P96dBh5B3mXrQtpDJ0RkxxaQKBgQDYfE6tQQnQSs2ewD6ae8Mu6j8ueDlVoZ37 +vze1QdBasR26xu2H8XBt3u41zc524BwQsB1GE1tnC8ZylrqwVEayK4FesSQRCO6o +yK8QSdb06I5J4TaN+TppCDPLzstOh0Dmxp+iFUGoErb7AEOLAJ/VebhF9kBZObL/ +HYy4Es+bQwKBgHW/4vYuB3IQXNCp/+V+X1BZ+iJOaves3gekekF+b2itFSKFD8JO +9LQhVfKmTheptdmHhgtF0keXxhV8C+vxX1Ndl7EF41FSh5vzmQRAtPHkCvFEviex +TFD70/gSb1lO1UA/Xbqk69yBcprVPAtFejss0EYx2MVj+CLftmIEwW0ZAoGBAIMG +EVQ45eikLXjkn78+Iq7VZbIJX6IdNBH29I+GqsUJJ5Yw6fh6P3KwF3qG+mvmTfYn +sUAFXS+r58rYwVsRVsxlGmKmUc7hmhibhaEVH72QtvWuEiexbRG+viKfIVuA7t39 +3wXpWZiQ4yBdU4Pgt9wrVEU7ukyGaHiReOa7s90jAoGAJc0K7smn98YutQQ+g2ur +ybfnsl0YdsksaP2S2zvZUmNevKPrgnaIDDabOlhYYga+AK1G3FQ7/nefUgiIg1Nd +kr+T6Q4osS3xHB6Az9p/jaF4R2KaWN2nNVCn7ecsmPxDdM7k1vLxaT26vwO9OP5f +YU/5CeIzrfA5nQyPZkOXZBk= +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUF2VIP4+AnEtb52KTCHbo4+fESfswDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xOTEwMzAyMjQ2MjBaFw0yMjA4 +MTkyMjQ2MjBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDL1hG+JYCfIPp3tlZ05J4pYIJ3Ckfs432bE3rYuWlR +2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRYOCI69s4+lP3DwR8uBCp9xyVkF8th +XfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+QfoxAb6tx0kEc7V3ozBLWoIDJjfwJ3 +NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIrAa7pxHzo/Nd0U3e7z+DlBcJV7dY6 +TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bClG0u7unS7QOBMd6bOGkeL+Bc+n22 +slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpXwj/Ek0F7AgMBAAGjUzBRMB0GA1Ud +DgQWBBT6Mf9uXFB67bY2PeW3GCTKfkO7vDAfBgNVHSMEGDAWgBT6Mf9uXFB67bY2 +PeW3GCTKfkO7vDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCZ +1+kTISX85v9/ag7glavaPFUYsOSOOofl8gSzov7L01YL+srq7tXdvZmWrjQ/dnOY +h18rp9rb24vwIYxNioNG/M2cW1jBJwEGsDPOwdPV1VPcRmmUJW9kY130gRHBCd/N +qB7dIkcQnpNsxPIIWI+sRQp73U0ijhOByDnCNHLHon6vbfFTwkO1XggmV5BdZ3uQ +JNJyckILyNzlhmf6zhonMp4lVzkgxWsAm2vgdawd6dmBa+7Avb2QK9s+IdUSutFh +DgW2L12Obgh12Y4sf1iKQXA0RbZ2k+XQIz8EKZa7vJQY0ciYXSgB/BV3a96xX3cx +LIPL8Vam8Ytkopi3gsGA +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sdk/azidentity/testdata/certificate_empty.pem b/sdk/azidentity/testdata/certificate_empty.pem new file mode 100644 index 000000000000..24fc8011f2a7 --- /dev/null +++ b/sdk/azidentity/testdata/certificate_empty.pem @@ -0,0 +1,21 @@ +-----BEGIN BLOCK----- +MIIDazCCAlOgAwIBAgIUF2VIP4+AnEtb52KTCHbo4+fESfswDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xOTEwMzAyMjQ2MjBaFw0yMjA4 +MTkyMjQ2MjBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDL1hG+JYCfIPp3tlZ05J4pYIJ3Ckfs432bE3rYuWlR +2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRYOCI69s4+lP3DwR8uBCp9xyVkF8th +XfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+QfoxAb6tx0kEc7V3ozBLWoIDJjfwJ3 +NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIrAa7pxHzo/Nd0U3e7z+DlBcJV7dY6 +TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bClG0u7unS7QOBMd6bOGkeL+Bc+n22 +slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpXwj/Ek0F7AgMBAAGjUzBRMB0GA1Ud +DgQWBBT6Mf9uXFB67bY2PeW3GCTKfkO7vDAfBgNVHSMEGDAWgBT6Mf9uXFB67bY2 +PeW3GCTKfkO7vDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCZ +1+kTISX85v9/ag7glavaPFUYsOSOOofl8gSzov7L01YL+srq7tXdvZmWrjQ/dnOY +h18rp9rb24vwIYxNioNG/M2cW1jBJwEGsDPOwdPV1VPcRmmUJW9kY130gRHBCd/N +qB7dIkcQnpNsxPIIWI+sRQp73U0ijhOByDnCNHLHon6vbfFTwkO1XggmV5BdZ3uQ +JNJyckILyNzlhmf6zhonMp4lVzkgxWsAm2vgdawd6dmBa+7Avb2QK9s+IdUSutFh +DgW2L12Obgh12Y4sf1iKQXA0RbZ2k+XQIz8EKZa7vJQY0ciYXSgB/BV3a96xX3cx +LIPL8Vam8Ytkopi3gsGA +-----END BLOCK----- \ No newline at end of file diff --git a/sdk/azidentity/testdata/certificate_formatA.pem b/sdk/azidentity/testdata/certificate_formatA.pem new file mode 100644 index 000000000000..4b66bfa021a0 --- /dev/null +++ b/sdk/azidentity/testdata/certificate_formatA.pem @@ -0,0 +1,49 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDL1hG+JYCfIPp3 +tlZ05J4pYIJ3Ckfs432bE3rYuWlR2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRY +OCI69s4+lP3DwR8uBCp9xyVkF8thXfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+Qf +oxAb6tx0kEc7V3ozBLWoIDJjfwJ3NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIr +Aa7pxHzo/Nd0U3e7z+DlBcJV7dY6TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bC +lG0u7unS7QOBMd6bOGkeL+Bc+n22slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpX +wj/Ek0F7AgMBAAECggEAblU3UWdXUcs2CCqIbcl52wfEVs8X05/n01MeAcWKvqYG +hvGcz7eLvhir5dQoXcF3VhybMrIe6C4WcBIiZSxGwxU+rwEP8YaLwX1UPfOrQM7s +sZTdFTLWfUslO3p7q300fdRA92iG9COMDZvkElh0cBvQksxs9sSr149l9vk+ymtC +uBhZtHG6Ki0BIMBNC9jGUqDuOatXl/dkK4tNjXrNJT7tVwzPaqnNALIWl6B+k9oQ +m1oNhSH2rvs9tw2ITXfIoIk9KdOMjQVUD43wKOaz0hNZhUsb1OFuls7UtRzaFcZH +rMd/M8DtA104QTTlHK+XS7r+nqdv7+ZyB+suTdM+oQKBgQDxCrJZU3hJ0eJ4VYhK +xGDfVGNpYxNkQ4CDB9fwRNbFr/Ck3kgzfE9QxTx1pJOolVmfuFmk9B86in4UNy91 +KdaqT79AU5RdOBXNN6tuMbLC0AVqe8sZq+1vWVVwbCstffxEMmyW1Ju/FLYPl2Zp +e5P96dBh5B3mXrQtpDJ0RkxxaQKBgQDYfE6tQQnQSs2ewD6ae8Mu6j8ueDlVoZ37 +vze1QdBasR26xu2H8XBt3u41zc524BwQsB1GE1tnC8ZylrqwVEayK4FesSQRCO6o +yK8QSdb06I5J4TaN+TppCDPLzstOh0Dmxp+iFUGoErb7AEOLAJ/VebhF9kBZObL/ +HYy4Es+bQwKBgHW/4vYuB3IQXNCp/+V+X1BZ+iJOaves3gekekF+b2itFSKFD8JO +9LQhVfKmTheptdmHhgtF0keXxhV8C+vxX1Ndl7EF41FSh5vzmQRAtPHkCvFEviex +TFD70/gSb1lO1UA/Xbqk69yBcprVPAtFejss0EYx2MVj+CLftmIEwW0ZAoGBAIMG +EVQ45eikLXjkn78+Iq7VZbIJX6IdNBH29I+GqsUJJ5Yw6fh6P3KwF3qG+mvmTfYn +sUAFXS+r58rYwVsRVsxlGmKmUc7hmhibhaEVH72QtvWuEiexbRG+viKfIVuA7t39 +3wXpWZiQ4yBdU4Pgt9wrVEU7ukyGaHiReOa7s90jAoGAJc0K7smn98YutQQ+g2ur +ybfnsl0YdsksaP2S2zvZUmNevKPrgnaIDDabOlhYYga+AK1G3FQ7/nefUgiIg1Nd +kr+T6Q4osS3xHB6Az9p/jaF4R2KaWN2nNVCn7ecsmPxDdM7k1vLxaT26vwO9OP5f +YU/5CeIzrfA5nQyPZkOXZBk= +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUF2VIP4+AnEtb52KTCHbo4+fESfswDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xOTEwMzAyMjQ2MjBaFw0yMjA4 +MTkyMjQ2MjBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDL1hG+JYCfIPp3tlZ05J4pYIJ3Ckfs432bE3rYuWlR +2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRYOCI69s4+lP3DwR8uBCp9xyVkF8th +XfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+QfoxAb6tx0kEc7V3ozBLWoIDJjfwJ3 +NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIrAa7pxHzo/Nd0U3e7z+DlBcJV7dY6 +TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bClG0u7unS7QOBMd6bOGkeL+Bc+n22 +slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpXwj/Ek0F7AgMBAAGjUzBRMB0GA1Ud +DgQWBBT6Mf9uXFB67bY2PeW3GCTKfkO7vDAfBgNVHSMEGDAWgBT6Mf9uXFB67bY2 +PeW3GCTKfkO7vDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCZ +1+kTISX85v9/ag7glavaPFUYsOSOOofl8gSzov7L01YL+srq7tXdvZmWrjQ/dnOY +h18rp9rb24vwIYxNioNG/M2cW1jBJwEGsDPOwdPV1VPcRmmUJW9kY130gRHBCd/N +qB7dIkcQnpNsxPIIWI+sRQp73U0ijhOByDnCNHLHon6vbfFTwkO1XggmV5BdZ3uQ +JNJyckILyNzlhmf6zhonMp4lVzkgxWsAm2vgdawd6dmBa+7Avb2QK9s+IdUSutFh +DgW2L12Obgh12Y4sf1iKQXA0RbZ2k+XQIz8EKZa7vJQY0ciYXSgB/BV3a96xX3cx +LIPL8Vam8Ytkopi3gsGA +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sdk/azidentity/testdata/certificate_formatB.pem b/sdk/azidentity/testdata/certificate_formatB.pem new file mode 100644 index 000000000000..3896c163dfed --- /dev/null +++ b/sdk/azidentity/testdata/certificate_formatB.pem @@ -0,0 +1,49 @@ +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUF2VIP4+AnEtb52KTCHbo4+fESfswDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xOTEwMzAyMjQ2MjBaFw0yMjA4 +MTkyMjQ2MjBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDL1hG+JYCfIPp3tlZ05J4pYIJ3Ckfs432bE3rYuWlR +2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRYOCI69s4+lP3DwR8uBCp9xyVkF8th +XfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+QfoxAb6tx0kEc7V3ozBLWoIDJjfwJ3 +NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIrAa7pxHzo/Nd0U3e7z+DlBcJV7dY6 +TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bClG0u7unS7QOBMd6bOGkeL+Bc+n22 +slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpXwj/Ek0F7AgMBAAGjUzBRMB0GA1Ud +DgQWBBT6Mf9uXFB67bY2PeW3GCTKfkO7vDAfBgNVHSMEGDAWgBT6Mf9uXFB67bY2 +PeW3GCTKfkO7vDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCZ +1+kTISX85v9/ag7glavaPFUYsOSOOofl8gSzov7L01YL+srq7tXdvZmWrjQ/dnOY +h18rp9rb24vwIYxNioNG/M2cW1jBJwEGsDPOwdPV1VPcRmmUJW9kY130gRHBCd/N +qB7dIkcQnpNsxPIIWI+sRQp73U0ijhOByDnCNHLHon6vbfFTwkO1XggmV5BdZ3uQ +JNJyckILyNzlhmf6zhonMp4lVzkgxWsAm2vgdawd6dmBa+7Avb2QK9s+IdUSutFh +DgW2L12Obgh12Y4sf1iKQXA0RbZ2k+XQIz8EKZa7vJQY0ciYXSgB/BV3a96xX3cx +LIPL8Vam8Ytkopi3gsGA +-----END CERTIFICATE----- +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDL1hG+JYCfIPp3 +tlZ05J4pYIJ3Ckfs432bE3rYuWlR2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRY +OCI69s4+lP3DwR8uBCp9xyVkF8thXfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+Qf +oxAb6tx0kEc7V3ozBLWoIDJjfwJ3NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIr +Aa7pxHzo/Nd0U3e7z+DlBcJV7dY6TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bC +lG0u7unS7QOBMd6bOGkeL+Bc+n22slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpX +wj/Ek0F7AgMBAAECggEAblU3UWdXUcs2CCqIbcl52wfEVs8X05/n01MeAcWKvqYG +hvGcz7eLvhir5dQoXcF3VhybMrIe6C4WcBIiZSxGwxU+rwEP8YaLwX1UPfOrQM7s +sZTdFTLWfUslO3p7q300fdRA92iG9COMDZvkElh0cBvQksxs9sSr149l9vk+ymtC +uBhZtHG6Ki0BIMBNC9jGUqDuOatXl/dkK4tNjXrNJT7tVwzPaqnNALIWl6B+k9oQ +m1oNhSH2rvs9tw2ITXfIoIk9KdOMjQVUD43wKOaz0hNZhUsb1OFuls7UtRzaFcZH +rMd/M8DtA104QTTlHK+XS7r+nqdv7+ZyB+suTdM+oQKBgQDxCrJZU3hJ0eJ4VYhK +xGDfVGNpYxNkQ4CDB9fwRNbFr/Ck3kgzfE9QxTx1pJOolVmfuFmk9B86in4UNy91 +KdaqT79AU5RdOBXNN6tuMbLC0AVqe8sZq+1vWVVwbCstffxEMmyW1Ju/FLYPl2Zp +e5P96dBh5B3mXrQtpDJ0RkxxaQKBgQDYfE6tQQnQSs2ewD6ae8Mu6j8ueDlVoZ37 +vze1QdBasR26xu2H8XBt3u41zc524BwQsB1GE1tnC8ZylrqwVEayK4FesSQRCO6o +yK8QSdb06I5J4TaN+TppCDPLzstOh0Dmxp+iFUGoErb7AEOLAJ/VebhF9kBZObL/ +HYy4Es+bQwKBgHW/4vYuB3IQXNCp/+V+X1BZ+iJOaves3gekekF+b2itFSKFD8JO +9LQhVfKmTheptdmHhgtF0keXxhV8C+vxX1Ndl7EF41FSh5vzmQRAtPHkCvFEviex +TFD70/gSb1lO1UA/Xbqk69yBcprVPAtFejss0EYx2MVj+CLftmIEwW0ZAoGBAIMG +EVQ45eikLXjkn78+Iq7VZbIJX6IdNBH29I+GqsUJJ5Yw6fh6P3KwF3qG+mvmTfYn +sUAFXS+r58rYwVsRVsxlGmKmUc7hmhibhaEVH72QtvWuEiexbRG+viKfIVuA7t39 +3wXpWZiQ4yBdU4Pgt9wrVEU7ukyGaHiReOa7s90jAoGAJc0K7smn98YutQQ+g2ur +ybfnsl0YdsksaP2S2zvZUmNevKPrgnaIDDabOlhYYga+AK1G3FQ7/nefUgiIg1Nd +kr+T6Q4osS3xHB6Az9p/jaF4R2KaWN2nNVCn7ecsmPxDdM7k1vLxaT26vwO9OP5f +YU/5CeIzrfA5nQyPZkOXZBk= +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/sdk/azidentity/testdata/certificate_nokey.pem b/sdk/azidentity/testdata/certificate_nokey.pem new file mode 100644 index 000000000000..465db0813cef --- /dev/null +++ b/sdk/azidentity/testdata/certificate_nokey.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUF2VIP4+AnEtb52KTCHbo4+fESfswDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xOTEwMzAyMjQ2MjBaFw0yMjA4 +MTkyMjQ2MjBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDL1hG+JYCfIPp3tlZ05J4pYIJ3Ckfs432bE3rYuWlR +2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRYOCI69s4+lP3DwR8uBCp9xyVkF8th +XfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+QfoxAb6tx0kEc7V3ozBLWoIDJjfwJ3 +NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIrAa7pxHzo/Nd0U3e7z+DlBcJV7dY6 +TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bClG0u7unS7QOBMd6bOGkeL+Bc+n22 +slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpXwj/Ek0F7AgMBAAGjUzBRMB0GA1Ud +DgQWBBT6Mf9uXFB67bY2PeW3GCTKfkO7vDAfBgNVHSMEGDAWgBT6Mf9uXFB67bY2 +PeW3GCTKfkO7vDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCZ +1+kTISX85v9/ag7glavaPFUYsOSOOofl8gSzov7L01YL+srq7tXdvZmWrjQ/dnOY +h18rp9rb24vwIYxNioNG/M2cW1jBJwEGsDPOwdPV1VPcRmmUJW9kY130gRHBCd/N +qB7dIkcQnpNsxPIIWI+sRQp73U0ijhOByDnCNHLHon6vbfFTwkO1XggmV5BdZ3uQ +JNJyckILyNzlhmf6zhonMp4lVzkgxWsAm2vgdawd6dmBa+7Avb2QK9s+IdUSutFh +DgW2L12Obgh12Y4sf1iKQXA0RbZ2k+XQIz8EKZa7vJQY0ciYXSgB/BV3a96xX3cx +LIPL8Vam8Ytkopi3gsGA +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sdk/azidentity/username_password_credential.go b/sdk/azidentity/username_password_credential.go new file mode 100644 index 000000000000..155a837644a7 --- /dev/null +++ b/sdk/azidentity/username_password_credential.go @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// UsernamePasswordCredential enables authentication to Azure Active Directory using a user's username and password. If the user has MFA enabled this +// credential will fail to get a token returning an AuthenticationFailureError. Also, this credential requires a high degree of trust and is not +// recommended outside of prototyping when more secure credentials can be used. +type UsernamePasswordCredential struct { + azcore.TokenCredential + client *aadIdentityClient + tenantID string // Gets the Azure Active Directory tenant (directory) ID of the service principal + clientID string // Gets the client (application) ID of the service principal + username string // Gets the user account's user name + password string // Gets the user account's password +} + +// NewUsernamePasswordCredential constructs a new UsernamePasswordCredential with the details needed to authenticate against Azure Active Directory with +// a simple username and password. +// tenantID: The Azure Active Directory tenant (directory) ID of the service principal. +// clientID: The client (application) ID of the service principal. +// username: A user's account username +// password: A user's account password +// options: TokenCredentialOptions used to configure the pipeline for the requests sent to Azure Active Directory. +func NewUsernamePasswordCredential(tenantID string, clientID string, username string, password string, options *TokenCredentialOptions) (*UsernamePasswordCredential, error) { + c, err := newAADIdentityClient(options) + if err != nil { + return nil, err + } + return &UsernamePasswordCredential{tenantID: tenantID, clientID: clientID, username: username, password: password, client: c}, nil +} + +// GetToken obtains a token from Azure Active Directory using the specified username and password. +// scopes: The list of scopes for which the token will have access. +// ctx: The context used to control the request lifetime. +// Returns an AccessToken which can be used to authenticate service client calls. +func (c *UsernamePasswordCredential) GetToken(ctx context.Context, opts azcore.TokenRequestOptions) (*azcore.AccessToken, error) { + tk, err := c.client.authenticateUsernamePassword(ctx, c.tenantID, c.clientID, c.username, c.password, opts.Scopes) + if err != nil { + addGetTokenFailureLogs("Username Password Credential", err) + return nil, err + } + azcore.Log().Write(LogCredential, logGetTokenSuccess(c, opts)) + return tk, err +} + +// AuthenticationPolicy implements the azcore.Credential interface on UsernamePasswordCredential. +func (c *UsernamePasswordCredential) AuthenticationPolicy(options azcore.AuthenticationPolicyOptions) azcore.Policy { + return newBearerTokenPolicy(c, options) +} diff --git a/sdk/azidentity/username_password_credential_test.go b/sdk/azidentity/username_password_credential_test.go new file mode 100644 index 000000000000..6fd36e547108 --- /dev/null +++ b/sdk/azidentity/username_password_credential_test.go @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "io/ioutil" + "net/http" + "net/url" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +func TestUsernamePasswordCredential_CreateAuthRequestSuccess(t *testing.T) { + cred, err := NewUsernamePasswordCredential(tenantID, clientID, "username", "password", nil) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + req, err := cred.client.createUsernamePasswordAuthRequest(cred.tenantID, cred.clientID, cred.username, cred.password, []string{scope}) + if err != nil { + t.Fatalf("Unexpectedly received an error: %v", err) + } + if req.Request.Header.Get(azcore.HeaderContentType) != azcore.HeaderURLEncoded { + t.Fatalf("Unexpected value for Content-Type header") + } + body, err := ioutil.ReadAll(req.Request.Body) + if err != nil { + t.Fatalf("Unable to read request body") + } + bodyStr := string(body) + reqQueryParams, err := url.ParseQuery(bodyStr) + if err != nil { + t.Fatalf("Unable to parse query params in request") + } + if reqQueryParams[qpResponseType][0] != "token" { + t.Fatalf("Unexpected response type") + } + if reqQueryParams[qpGrantType][0] != "password" { + t.Fatalf("Unexpected grant type") + } + if reqQueryParams[qpClientID][0] != clientID { + t.Fatalf("Unexpected client ID in the client_id header") + } + if reqQueryParams[qpUsername][0] != "username" { + t.Fatalf("Unexpected username in the username header") + } + if reqQueryParams[qpPassword][0] != "password" { + t.Fatalf("Unexpected password in the password header") + } + if reqQueryParams[qpScope][0] != scope { + t.Fatalf("Unexpected scope in scope header") + } + if req.Request.URL.Host != defaultTestAuthorityHost { + t.Fatalf("Unexpected default authority host") + } + if req.Request.URL.Scheme != "https" { + t.Fatalf("Wrong request scheme") + } +} + +func TestUsernamePasswordCredential_GetTokenSuccess(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srvURL := srv.URL() + cred, err := NewUsernamePasswordCredential(tenantID, clientID, "username", "password", &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Expected an empty error but received: %s", err.Error()) + } +} + +func TestUsernamePasswordCredential_GetTokenInvalidCredentials(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusUnauthorized)) + srvURL := srv.URL() + cred, err := NewUsernamePasswordCredential(tenantID, clientID, "username", "wrong_password", &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{scope}}) + if err == nil { + t.Fatalf("Expected an error but did not receive one.") + } +} + +func TestBearerPolicy_UsernamePasswordCredential(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srvURL := srv.URL() + cred, err := NewUsernamePasswordCredential(tenantID, clientID, "username", "password", &TokenCredentialOptions{HTTPClient: srv, AuthorityHost: &srvURL}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewTelemetryPolicy(azcore.TelemetryOptions{}), + azcore.NewUniqueRequestIDPolicy(), + azcore.NewRetryPolicy(nil), + cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + azcore.NewRequestLogPolicy(azcore.RequestLogOptions{})) + _, err = pipeline.Do(context.Background(), azcore.NewRequest(http.MethodGet, srv.URL())) + if err != nil { + t.Fatalf("Expected an empty error but receive: %v", err) + } +} From f65a74687ef7be4c58f232b9d6dc49e2bf86ae31 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 22 Jul 2020 19:12:15 -0700 Subject: [PATCH 2/2] update dependencies --- sdk/azidentity/go.mod | 4 ++-- sdk/azidentity/go.sum | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sdk/azidentity/go.mod b/sdk/azidentity/go.mod index ec84ba11f6ce..29e3ad3eb3c6 100644 --- a/sdk/azidentity/go.mod +++ b/sdk/azidentity/go.mod @@ -3,6 +3,6 @@ module github.com/Azure/azure-sdk-for-go/sdk/azidentity go 1.13 require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.0 - github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.1 + github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.2 ) diff --git a/sdk/azidentity/go.sum b/sdk/azidentity/go.sum index 00dfdfebed95..f4a13bae9f20 100644 --- a/sdk/azidentity/go.sum +++ b/sdk/azidentity/go.sum @@ -1,4 +1,6 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.0 h1:VdhfbVpQ3dkhXYOx/Wj1+utikcZkZSZSmpqmXWwaNJY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.0/go.mod h1:hL9TGc07RkJVzDIBxsYXC/r0M+YiRkvl4z1elXCD+8s= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.0 h1:cLpVMIkXC/umSP9DMz9I6FttDWJAsmvhpaB6MlkagGY= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.0/go.mod h1:Q+TCQnSr+clUU0JU+xrHZ3slYCxw17AOFdvWFpQXjAY= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.1 h1:f50d2lvCzW+yaZqWLzd1kU5808BeDBGdClUQybSzSVU= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.9.1/go.mod h1:fBbm1JLvufiabxBiiZWThNODf8+bARgZ81aP3CEx3sg= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.1 h1:xY9/wUJ8PcxmTEJ6z+0qKuj9rb3Aw9nhiL+ik5evR/g= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.1/go.mod h1:Q+TCQnSr+clUU0JU+xrHZ3slYCxw17AOFdvWFpQXjAY= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.2 h1:d1hG+ChFZNyblEulXP3unkwzUmh83grtG3t4sMV+6Xg= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.2.2/go.mod h1:Q+TCQnSr+clUU0JU+xrHZ3slYCxw17AOFdvWFpQXjAY=