diff --git a/auth/auth_test.go b/auth/auth_test.go index 6e021c42..d15ffa21 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -5,6 +5,8 @@ import ( "os" "testing" + "golang.org/x/oauth2" + "github.com/manicminer/hamilton/auth" "github.com/manicminer/hamilton/environments" "github.com/manicminer/hamilton/internal/test" @@ -23,15 +25,16 @@ var ( ) func TestClientCertificateAuthorizerV1(t *testing.T) { - testClientCertificateAuthorizer(t, auth.TokenVersion1) + ctx := context.Background() + testClientCertificateAuthorizer(ctx, t, auth.TokenVersion1) } func TestClientCertificateAuthorizerV2(t *testing.T) { - testClientCertificateAuthorizer(t, auth.TokenVersion2) + ctx := context.Background() + testClientCertificateAuthorizer(ctx, t, auth.TokenVersion2) } -func testClientCertificateAuthorizer(t *testing.T, tokenVersion auth.TokenVersion) { - ctx := context.Background() +func testClientCertificateAuthorizer(ctx context.Context, t *testing.T, tokenVersion auth.TokenVersion) (token *oauth2.Token) { pfx := utils.Base64DecodeCertificate(clientCertificate) auth, err := auth.NewClientCertificateAuthorizer(ctx, environments.Global, auth.MsGraph, tokenVersion, tenantId, clientId, pfx, clientCertificatePath, clientCertPassword) if err != nil { @@ -40,7 +43,7 @@ func testClientCertificateAuthorizer(t *testing.T, tokenVersion auth.TokenVersio if auth == nil { t.Fatal("auth is nil, expected Authorizer") } - token, err := auth.Token() + token, err = auth.Token() if err != nil { t.Fatalf("auth.Token(): %v", err) } @@ -50,18 +53,20 @@ func testClientCertificateAuthorizer(t *testing.T, tokenVersion auth.TokenVersio if token.AccessToken == "" { t.Fatal("token.AccessToken was empty") } + return } func TestClientSecretAuthorizerV1(t *testing.T) { - testClientSecretAuthorizer(t, auth.TokenVersion1) + ctx := context.Background() + testClientSecretAuthorizer(ctx, t, auth.TokenVersion1) } func TestClientSecretAuthorizerV2(t *testing.T) { - testClientSecretAuthorizer(t, auth.TokenVersion2) + ctx := context.Background() + testClientSecretAuthorizer(ctx, t, auth.TokenVersion2) } -func testClientSecretAuthorizer(t *testing.T, tokenVersion auth.TokenVersion) { - ctx := context.Background() +func testClientSecretAuthorizer(ctx context.Context, t *testing.T, tokenVersion auth.TokenVersion) (token *oauth2.Token) { auth, err := auth.NewClientSecretAuthorizer(ctx, environments.Global, auth.MsGraph, tokenVersion, tenantId, clientId, clientSecret) if err != nil { t.Fatalf("NewClientSecretAuthorizer(): %v", err) @@ -69,7 +74,7 @@ func testClientSecretAuthorizer(t *testing.T, tokenVersion auth.TokenVersion) { if auth == nil { t.Fatal("auth is nil, expected Authorizer") } - token, err := auth.Token() + token, err = auth.Token() if err != nil { t.Fatalf("auth.Token(): %v", err) } @@ -79,10 +84,15 @@ func testClientSecretAuthorizer(t *testing.T, tokenVersion auth.TokenVersion) { if token.AccessToken == "" { t.Fatalf("token.AccessToken was empty") } + return } func TestAzureCliAuthorizer(t *testing.T) { ctx := context.Background() + testAzureCliAuthorizer(ctx, t) +} + +func testAzureCliAuthorizer(ctx context.Context, t *testing.T) (token *oauth2.Token) { auth, err := auth.NewAzureCliAuthorizer(ctx, auth.MsGraph, tenantId) if err != nil { t.Fatalf("NewAzureCliAuthorizer(): %v", err) @@ -90,7 +100,7 @@ func TestAzureCliAuthorizer(t *testing.T) { if auth == nil { t.Fatal("auth is nil, expected Authorizer") } - token, err := auth.Token() + token, err = auth.Token() if err != nil { t.Fatalf("auth.Token(): %v", err) } @@ -100,6 +110,7 @@ func TestAzureCliAuthorizer(t *testing.T) { if token.AccessToken == "" { t.Fatalf("token.AccessToken was empty") } + return } func TestMsiAuthorizer(t *testing.T) { diff --git a/auth/claims.go b/auth/claims.go index 9694b232..bfa1e291 100644 --- a/auth/claims.go +++ b/auth/claims.go @@ -32,7 +32,7 @@ func ParseClaims(token *oauth2.Token) (claims Claims, err error) { return } jwt := strings.Split(token.AccessToken, ".") - payload, err := base64.RawStdEncoding.DecodeString(jwt[1]) + payload, err := base64.RawURLEncoding.DecodeString(jwt[1]) if err != nil { return } diff --git a/auth/claims_test.go b/auth/claims_test.go new file mode 100644 index 00000000..e8af5ae6 --- /dev/null +++ b/auth/claims_test.go @@ -0,0 +1,59 @@ +package auth_test + +import ( + "context" + "testing" + + "github.com/manicminer/hamilton/auth" +) + +func TestParseClaims_azureCli(t *testing.T) { + ctx := context.Background() + token := testAzureCliAuthorizer(ctx, t) + claims, err := auth.ParseClaims(token) + if err != nil { + t.Fatal(err) + } + checkClaims(t, claims) +} + +func TestParseClaims_clientCertificate(t *testing.T) { + ctx := context.Background() + token := testClientCertificateAuthorizer(ctx, t, auth.TokenVersion2) + claims, err := auth.ParseClaims(token) + if err != nil { + t.Fatal(err) + } + checkClaims(t, claims) +} + +func TestParseClaims_clientSecret(t *testing.T) { + ctx := context.Background() + token := testClientSecretAuthorizer(ctx, t, auth.TokenVersion2) + claims, err := auth.ParseClaims(token) + if err != nil { + t.Fatal(err) + } + checkClaims(t, claims) +} + +func checkClaims(t *testing.T, claims auth.Claims) { + if claims.AppId == "" { + t.Fatal("claims.AppId was empty") + } + if claims.Audience == "" { + t.Fatal("claims.Audience was empty") + } + if claims.Issuer == "" { + t.Fatal("claims.Issuer was empty") + } + if len(claims.Roles) == 0 && claims.Scopes == "" { + t.Fatal("claims.Roles and claims.Scopes were empty") + } + if claims.Subject == "" { + t.Fatal("claims.Subject was empty") + } + if claims.TenantId == "" { + t.Fatal("claims.TenantId was empty") + } +}