Skip to content

Commit

Permalink
fix: enable update jwt via callback for workloadidentity (#719)
Browse files Browse the repository at this point in the history
* fix: enable update jwt via callback for workloadidentity

* fix breaking changes and add unit test

* typo

* fix comment

* add unit test
  • Loading branch information
cvvz authored Mar 15, 2023
1 parent ee71315 commit 553a90a
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 4 deletions.
1 change: 1 addition & 0 deletions autorest/adal/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/Azure/go-autorest/logger v0.2.1
github.com/Azure/go-autorest/tracing v0.6.0
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/stretchr/testify v1.8.2
golang.org/x/crypto v0.6.0
)

Expand Down
17 changes: 17 additions & 0 deletions autorest/adal/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,20 @@ github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+Z
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo=
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
Expand Down Expand Up @@ -39,3 +51,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
42 changes: 38 additions & 4 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ type TokenRefreshCallback func(Token) error
// TokenRefresh is a type representing a custom callback to refresh a token
type TokenRefresh func(ctx context.Context, resource string) (*Token, error)

// JWTCallback is the type representing callback that will be called to get the federated OIDC JWT
type JWTCallback func() (string, error)

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
Expand Down Expand Up @@ -367,14 +370,18 @@ func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, err

// ServicePrincipalFederatedSecret implements ServicePrincipalSecret for Federated JWTs.
type ServicePrincipalFederatedSecret struct {
jwt string
jwtCallback JWTCallback
}

// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It will populate the form submitted during OAuth Token Acquisition using a JWT signed by an OIDC issuer.
func (secret *ServicePrincipalFederatedSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
func (secret *ServicePrincipalFederatedSecret) SetAuthenticationValues(_ *ServicePrincipalToken, v *url.Values) error {
jwt, err := secret.jwtCallback()
if err != nil {
return err
}

v.Set("client_assertion", secret.jwt)
v.Set("client_assertion", jwt)
v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
return nil
}
Expand Down Expand Up @@ -687,6 +694,8 @@ func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clie
}

// NewServicePrincipalTokenFromFederatedToken creates a ServicePrincipalToken from the supplied federated OIDC JWT.
//
// Deprecated: Use NewServicePrincipalTokenFromFederatedTokenWithCallback to refresh jwt dynamically.
func NewServicePrincipalTokenFromFederatedToken(oauthConfig OAuthConfig, clientID string, jwt string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
Expand All @@ -700,12 +709,37 @@ func NewServicePrincipalTokenFromFederatedToken(oauthConfig OAuthConfig, clientI
if jwt == "" {
return nil, fmt.Errorf("parameter 'jwt' cannot be empty")
}
return NewServicePrincipalTokenFromFederatedTokenCallback(
oauthConfig,
clientID,
func() (string, error) {
return jwt, nil
},
resource,
callbacks...,
)
}

// NewServicePrincipalTokenFromFederatedTokenCallback creates a ServicePrincipalToken from the supplied federated OIDC JWTCallback.
func NewServicePrincipalTokenFromFederatedTokenCallback(oauthConfig OAuthConfig, clientID string, jwtCallback JWTCallback, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if jwtCallback == nil {
return nil, fmt.Errorf("parameter 'jwtCallback' cannot be empty")
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalFederatedSecret{
jwt: jwt,
jwtCallback: jwtCallback,
},
callbacks...,
)
Expand Down
65 changes: 65 additions & 0 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import (
"io/ioutil"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
Expand All @@ -37,6 +39,7 @@ import (
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/mocks"
jwt "github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
)

const (
Expand Down Expand Up @@ -206,6 +209,62 @@ func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) {
}
}

func TestFederatedTokenRefreshUsesJwtCallback(t *testing.T) {
baseDir, err := os.MkdirTemp("", "")
assert.NoError(t, err)
jwtFile := filepath.Join(baseDir, "token")

jwtCallback := func() (string, error) {
jwt, err := os.ReadFile(jwtFile)
if err != nil {
return "", fmt.Errorf("failed to read a file with a federated token: %w", err)
}
return string(jwt), nil
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwt := r.FormValue("client_assertion")
refreshToken := r.FormValue("refresh_token")

if jwt == "aaa.aaa" {
w.Write([]byte(`{"access_token":"A","expires_in":"3600"}`))
} else if jwt == "bbb.bbb" {
w.Write([]byte(`{"access_token":"B","expires_in":"3600","refresh_token":"R"}`))
} else if refreshToken == "R" {
w.Write([]byte(`{"access_token":"C","expires_in":"3600"}`))
} else {
w.WriteHeader(http.StatusBadRequest)
}
}))

spt := newServicePrincipalTokenFederatedJwtCallback(t, jwtCallback, server.URL)

// token file does not exist, no such file error
err = spt.refreshInternal(context.Background(), "")
assert.ErrorIs(t, err, os.ErrNotExist)

// get jwt token from jwtFile
err = os.WriteFile(jwtFile, []byte("aaa.aaa"), 0600)
assert.NoError(t, err)
err = spt.refreshInternal(context.Background(), "")
assert.NoError(t, err)
assert.Equal(t, "A", spt.inner.Token.AccessToken)

// jwtFile is refreshed
err = os.WriteFile(jwtFile, []byte("bbb.bbb"), 0600)
assert.NoError(t, err)
err = spt.refreshInternal(context.Background(), "")
assert.NoError(t, err)
assert.Equal(t, "B", spt.inner.Token.AccessToken)
// refresh_token is set
assert.Equal(t, "R", spt.inner.Token.RefreshToken)

// after refresh_token is set, the callback won't be called
err = spt.refreshInternal(context.Background(), "")
assert.NoError(t, err)
assert.Equal(t, "C", spt.inner.Token.AccessToken)
}

func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
spt := newServicePrincipalToken()

Expand Down Expand Up @@ -1602,3 +1661,9 @@ func newServicePrincipalTokenFederatedJwt(t *testing.T) *ServicePrincipalToken {
spt, _ := NewServicePrincipalTokenFromFederatedToken(TestOAuthConfig, "id", signedString, "resource")
return spt
}

func newServicePrincipalTokenFederatedJwtCallback(t *testing.T, callback JWTCallback, fakeEndpoint string) *ServicePrincipalToken {
outhConfig, _ := NewOAuthConfig(fakeEndpoint, TestTenantID)
spt, _ := NewServicePrincipalTokenFromFederatedTokenCallback(*outhConfig, "id", callback, "resource")
return spt
}

0 comments on commit 553a90a

Please sign in to comment.