diff --git a/authorize_helper_test.go b/authorize_helper_test.go index 30393093a..44c803310 100644 --- a/authorize_helper_test.go +++ b/authorize_helper_test.go @@ -75,22 +75,22 @@ func TestDoesClientWhiteListRedirect(t *testing.T) { isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"wta://auth"}}, - url: "wta://auth", + client: &DefaultClient{RedirectURIs: []string{"wta://auth"}}, + url: "wta://auth", expected: "wta://auth", - isError: false, + isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"wta:///auth"}}, - url: "wta:///auth", + client: &DefaultClient{RedirectURIs: []string{"wta:///auth"}}, + url: "wta:///auth", expected: "wta:///auth", - isError: false, + isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"wta://foo/auth"}}, - url: "wta://foo/auth", + client: &DefaultClient{RedirectURIs: []string{"wta://foo/auth"}}, + url: "wta://foo/auth", expected: "wta://foo/auth", - isError: false, + isError: false, }, { client: &DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, @@ -131,10 +131,10 @@ func TestDoesClientWhiteListRedirect(t *testing.T) { } func TestIsRedirectURISecure(t *testing.T) { - for d, c := range []struct{ - u string + for d, c := range []struct { + u string err bool - } { + }{ {u: "http://google.com", err: true}, {u: "https://google.com", err: false}, {u: "http://localhost", err: false}, @@ -144,4 +144,4 @@ func TestIsRedirectURISecure(t *testing.T) { require.Nil(t, err) assert.Equal(t, !c.err, IsRedirectURISecure(uu), "case %d", d) } -} \ No newline at end of file +} diff --git a/compose/compose.go b/compose/compose.go index 4ab8b3ee2..c3fe6ef02 100644 --- a/compose/compose.go +++ b/compose/compose.go @@ -6,7 +6,7 @@ import ( "github.com/ory-am/fosite" ) -type handler func(config *Config, storage interface{}, strategy interface{}) interface{} +type Factory func(config *Config, storage interface{}, strategy interface{}) interface{} // Compose takes a config, a storage, a strategy and handlers to instantiate an OAuth2Provider: // @@ -30,7 +30,7 @@ type handler func(config *Config, storage interface{}, strategy interface{}) int // ) // // Compose makes use of interface{} types in order to be able to handle a all types of stores, strategies and handlers. -func Compose(config *Config, storage interface{}, strategy interface{}, handlers ...handler) fosite.OAuth2Provider { +func Compose(config *Config, storage interface{}, strategy interface{}, factories ...Factory) fosite.OAuth2Provider { f := &fosite.Fosite{ Store: storage.(fosite.Storage), AuthorizeEndpointHandlers: fosite.AuthorizeEndpointHandlers{}, @@ -41,8 +41,8 @@ func Compose(config *Config, storage interface{}, strategy interface{}, handlers ScopeStrategy: fosite.HierarchicScopeStrategy, } - for _, h := range handlers { - res := h(config, storage, strategy) + for _, factory := range factories { + res := factory(config, storage, strategy) if ah, ok := res.(fosite.AuthorizeEndpointHandler); ok { f.AuthorizeEndpointHandlers.Append(ah) } diff --git a/compose/compose_oauth2.go b/compose/compose_oauth2.go index eed09d598..9441be97f 100644 --- a/compose/compose_oauth2.go +++ b/compose/compose_oauth2.go @@ -87,3 +87,17 @@ func OAuth2TokenIntrospectionFactory(config *Config, storage interface{}, strate ScopeStrategy: fosite.HierarchicScopeStrategy, } } + +// OAuth2StatelessJWTIntrospectionFactory creates an OAuth2 token introspection handler and +// registers an access token validator. This can only be used to validate JWTs and does so +// statelessly, meaning it uses only the data available in the JWT itself, and does not access the +// storage implementation at all. +// +// Due to the stateless nature of this factory, THE BUILT-IN REVOCATION MECHANISMS WILL NOT WORK. +// If you need revocation, you can validate JWTs statefully, using the other factories. +func OAuth2StatelessJWTIntrospectionFactory(config *Config, storage interface{}, strategy interface{}) interface{} { + return &oauth2.StatelessJWTValidator{ + JWTAccessTokenStrategy: strategy.(oauth2.JWTAccessTokenStrategy), + ScopeStrategy: fosite.HierarchicScopeStrategy, + } +} diff --git a/handler/oauth2/introspector_jwt.go b/handler/oauth2/introspector_jwt.go new file mode 100644 index 000000000..d49410667 --- /dev/null +++ b/handler/oauth2/introspector_jwt.go @@ -0,0 +1,43 @@ +package oauth2 + +import ( + "github.com/ory-am/fosite" + "github.com/pkg/errors" + "golang.org/x/net/context" +) + +type JWTAccessTokenStrategy interface { + AccessTokenStrategy + JWTStrategy +} + +type StatelessJWTValidator struct { + JWTAccessTokenStrategy + ScopeStrategy fosite.ScopeStrategy +} + +func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, token string, tokenType fosite.TokenType, accessRequest fosite.AccessRequester, scopes []string) (err error) { + or, err := v.JWTAccessTokenStrategy.ValidateJWT(fosite.AccessToken, token) + if err != nil { + return err + } + + for _, scope := range scopes { + if scope == "" { + continue + } + + if !v.ScopeStrategy(or.GetGrantedScopes(), scope) { + return errors.WithStack(fosite.ErrInvalidScope) + } + } + + accessRequest.Merge(or) + return nil +} + +// Revocation is not supported with the stateless validator. If you need revocation, use the +// CoreValidator struct instead. +func (v *StatelessJWTValidator) RevokeToken(ctx context.Context, token string, tokenType fosite.TokenType) error { + return errors.Wrap(fosite.ErrMisconfiguration, "Token revocation is not supported") +} diff --git a/handler/oauth2/introspector_jwt_test.go b/handler/oauth2/introspector_jwt_test.go new file mode 100644 index 000000000..f35aaba84 --- /dev/null +++ b/handler/oauth2/introspector_jwt_test.go @@ -0,0 +1,127 @@ +package oauth2 + +import ( + "encoding/base64" + "strings" + "testing" + + "github.com/ory-am/fosite" + "github.com/ory-am/fosite/internal" + "github.com/ory-am/fosite/token/jwt" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestIntrospectJWT(t *testing.T) { + strat := &RS256JWTStrategy{ + RS256JWTStrategy: &jwt.RS256JWTStrategy{ + PrivateKey: internal.MustRSAKey(), + }, + } + + v := &StatelessJWTValidator{ + JWTAccessTokenStrategy: strat, + ScopeStrategy: fosite.HierarchicScopeStrategy, + } + + for k, c := range []struct { + description string + token func() string + expectErr error + scopes []string + }{ + { + description: "should fail because jwt is expired", + token: func() string { + jwt := jwtExpiredCase(fosite.AccessToken) + token, _, err := strat.GenerateAccessToken(nil, jwt) + assert.NoError(t, err) + return token + }, + expectErr: fosite.ErrTokenExpired, + }, + { + description: "should pass because scope was granted", + token: func() string { + jwt := jwtValidCase(fosite.AccessToken) + jwt.GrantedScopes = []string{"foo", "bar"} + token, _, err := strat.GenerateAccessToken(nil, jwt) + assert.NoError(t, err) + return token + }, + scopes: []string{"foo"}, + }, + { + description: "should fail because scope was not granted", + token: func() string { + jwt := jwtValidCase(fosite.AccessToken) + token, _, err := strat.GenerateAccessToken(nil, jwt) + assert.NoError(t, err) + return token + }, + scopes: []string{"foo"}, + expectErr: fosite.ErrInvalidScope, + }, + { + description: "should fail because signature is invalid", + token: func() string { + jwt := jwtValidCase(fosite.AccessToken) + token, _, err := strat.GenerateAccessToken(nil, jwt) + assert.NoError(t, err) + parts := strings.Split(token, ".") + dec, err := base64.RawURLEncoding.DecodeString(parts[1]) + assert.NoError(t, err) + s := strings.Replace(string(dec), "peter", "piper", -1) + parts[1] = base64.RawURLEncoding.EncodeToString([]byte(s)) + return strings.Join(parts, ".") + }, + expectErr: fosite.ErrTokenSignatureMismatch, + }, + { + description: "should pass", + token: func() string { + jwt := jwtValidCase(fosite.AccessToken) + token, _, err := strat.GenerateAccessToken(nil, jwt) + assert.NoError(t, err) + return token + }, + }, + } { + if c.scopes == nil { + c.scopes = []string{} + } + areq := fosite.NewAccessRequest(nil) + err := v.IntrospectToken(nil, c.token(), fosite.AccessToken, areq, c.scopes) + + assert.True(t, errors.Cause(err) == c.expectErr, "(%d) %s\n%s\n%s", k, c.description, err, c.expectErr) + + if err == nil { + assert.Equal(t, "peter", areq.Session.GetSubject()) + } + + t.Logf("Passed test case %d", k) + } +} + +func BenchmarkIntrospectJWT(b *testing.B) { + strat := &RS256JWTStrategy{ + RS256JWTStrategy: &jwt.RS256JWTStrategy{ + PrivateKey: internal.MustRSAKey(), + }, + } + + v := &StatelessJWTValidator{ + JWTAccessTokenStrategy: strat, + } + + jwt := jwtValidCase(fosite.AccessToken) + token, _, err := strat.GenerateAccessToken(nil, jwt) + assert.NoError(b, err) + areq := fosite.NewAccessRequest(nil) + + for n := 0; n < b.N; n++ { + err = v.IntrospectToken(nil, token, fosite.AccessToken, areq, []string{}) + } + + assert.NoError(b, err) +} diff --git a/handler/oauth2/strategy.go b/handler/oauth2/strategy.go index 0822aac1a..6ee833465 100644 --- a/handler/oauth2/strategy.go +++ b/handler/oauth2/strategy.go @@ -11,6 +11,10 @@ type CoreStrategy interface { AuthorizeCodeStrategy } +type JWTStrategy interface { + ValidateJWT(tokenType fosite.TokenType, token string) (requester fosite.Requester, err error) +} + type AccessTokenStrategy interface { AccessTokenSignature(token string) string GenerateAccessToken(ctx context.Context, requester fosite.Requester) (token string, signature string, err error) diff --git a/handler/oauth2/strategy_jwt.go b/handler/oauth2/strategy_jwt.go index 2ead9cd2d..8c5d3a6c0 100644 --- a/handler/oauth2/strategy_jwt.go +++ b/handler/oauth2/strategy_jwt.go @@ -2,6 +2,7 @@ package oauth2 import ( "strings" + "time" jwtx "github.com/dgrijalva/jwt-go" "github.com/ory-am/fosite" @@ -13,6 +14,7 @@ import ( // RS256JWTStrategy is a JWT RS256 strategy. type RS256JWTStrategy struct { *jwt.RS256JWTStrategy + Issuer string } func (h RS256JWTStrategy) signature(token string) string { @@ -36,12 +38,42 @@ func (h RS256JWTStrategy) AuthorizeCodeSignature(token string) string { return h.signature(token) } +func (h *RS256JWTStrategy) ValidateJWT(tokenType fosite.TokenType, token string) (requester fosite.Requester, err error) { + t, err := h.validate(token) + if err != nil { + return nil, err + } + + claims := jwt.JWTClaims{} + claims.FromMapClaims(t.Claims.(jwtx.MapClaims)) + + requester = &fosite.Request{ + Client: &fosite.DefaultClient{}, + RequestedAt: claims.IssuedAt, + Session: &JWTSession{ + JWTClaims: &claims, + JWTHeader: &jwt.Headers{ + Extra: make(map[string]interface{}), + }, + ExpiresAt: map[fosite.TokenType]time.Time{ + tokenType: claims.ExpiresAt, + }, + Subject: claims.Subject, + }, + Scopes: claims.Scope, + GrantedScopes: claims.Scope, + } + + return +} + func (h *RS256JWTStrategy) GenerateAccessToken(_ context.Context, requester fosite.Requester) (token string, signature string, err error) { return h.generate(fosite.AccessToken, requester) } func (h *RS256JWTStrategy) ValidateAccessToken(_ context.Context, _ fosite.Requester, token string) error { - return h.validate(token) + _, err := h.validate(token) + return err } func (h *RS256JWTStrategy) GenerateRefreshToken(_ context.Context, requester fosite.Requester) (token string, signature string, err error) { @@ -49,7 +81,8 @@ func (h *RS256JWTStrategy) GenerateRefreshToken(_ context.Context, requester fos } func (h *RS256JWTStrategy) ValidateRefreshToken(_ context.Context, _ fosite.Requester, token string) error { - return h.validate(token) + _, err := h.validate(token) + return err } func (h *RS256JWTStrategy) GenerateAuthorizeCode(_ context.Context, requester fosite.Requester) (token string, signature string, err error) { @@ -57,45 +90,47 @@ func (h *RS256JWTStrategy) GenerateAuthorizeCode(_ context.Context, requester fo } func (h *RS256JWTStrategy) ValidateAuthorizeCode(_ context.Context, requester fosite.Requester, token string) error { - return h.validate(token) + _, err := h.validate(token) + return err } -func (h *RS256JWTStrategy) validate(token string) error { - t, err := h.RS256JWTStrategy.Decode(token) - if err != nil { - return err +func (h *RS256JWTStrategy) validate(token string) (t *jwtx.Token, err error) { + t, err = h.RS256JWTStrategy.Decode(token) + + if err == nil { + err = t.Claims.Valid() } - // validate the token - if err = t.Claims.Valid(); err != nil { - if e, ok := err.(*jwtx.ValidationError); ok { + if err != nil { + if e, ok := errors.Cause(err).(*jwtx.ValidationError); ok { switch e.Errors { case jwtx.ValidationErrorMalformed: - return errors.Wrap(fosite.ErrInvalidTokenFormat, err.Error()) + err = errors.Wrap(fosite.ErrInvalidTokenFormat, err.Error()) case jwtx.ValidationErrorUnverifiable: - return errors.Wrap(fosite.ErrTokenSignatureMismatch, err.Error()) + err = errors.Wrap(fosite.ErrTokenSignatureMismatch, err.Error()) case jwtx.ValidationErrorSignatureInvalid: - return errors.Wrap(fosite.ErrTokenSignatureMismatch, err.Error()) + err = errors.Wrap(fosite.ErrTokenSignatureMismatch, err.Error()) case jwtx.ValidationErrorAudience: - return errors.Wrap(fosite.ErrTokenClaim, err.Error()) + err = errors.Wrap(fosite.ErrTokenClaim, err.Error()) case jwtx.ValidationErrorExpired: - return errors.Wrap(fosite.ErrTokenExpired, err.Error()) + err = errors.Wrap(fosite.ErrTokenExpired, err.Error()) case jwtx.ValidationErrorIssuedAt: - return errors.Wrap(fosite.ErrTokenClaim, err.Error()) + err = errors.Wrap(fosite.ErrTokenClaim, err.Error()) case jwtx.ValidationErrorIssuer: - return errors.Wrap(fosite.ErrTokenClaim, err.Error()) + err = errors.Wrap(fosite.ErrTokenClaim, err.Error()) case jwtx.ValidationErrorNotValidYet: - return errors.Wrap(fosite.ErrTokenClaim, err.Error()) + err = errors.Wrap(fosite.ErrTokenClaim, err.Error()) case jwtx.ValidationErrorId: - return errors.Wrap(fosite.ErrTokenClaim, err.Error()) + err = errors.Wrap(fosite.ErrTokenClaim, err.Error()) case jwtx.ValidationErrorClaimsInvalid: - return errors.Wrap(fosite.ErrTokenClaim, err.Error()) + err = errors.Wrap(fosite.ErrTokenClaim, err.Error()) + default: + err = errors.Wrap(fosite.ErrRequestUnauthorized, err.Error()) } - return errors.Wrap(fosite.ErrRequestUnauthorized, err.Error()) } } - return nil + return } func (h *RS256JWTStrategy) generate(tokenType fosite.TokenType, requester fosite.Requester) (string, string, error) { @@ -106,6 +141,17 @@ func (h *RS256JWTStrategy) generate(tokenType fosite.TokenType, requester fosite } else { claims := jwtSession.GetJWTClaims() claims.ExpiresAt = jwtSession.GetExpiresAt(tokenType) + + if claims.IssuedAt.IsZero() { + claims.IssuedAt = time.Now() + } + + if claims.Issuer == "" { + claims.Issuer = h.Issuer + } + + claims.Scope = requester.GetGrantedScopes() + return h.RS256JWTStrategy.Generate(claims.ToMapClaims(), jwtSession.GetJWTHeader()) } } diff --git a/integration/helper_endpoints_test.go b/integration/helper_endpoints_test.go index ffae03a98..7cd72bdf2 100644 --- a/integration/helper_endpoints_test.go +++ b/integration/helper_endpoints_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/ory-am/fosite" + "github.com/ory-am/fosite/handler/oauth2" "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -113,17 +114,17 @@ func authCallbackHandler(t *testing.T) func(rw http.ResponseWriter, req *http.Re } } -func tokenEndpointHandler(t *testing.T, oauth2 fosite.OAuth2Provider) func(rw http.ResponseWriter, req *http.Request) { +func tokenEndpointHandler(t *testing.T, provider fosite.OAuth2Provider) func(rw http.ResponseWriter, req *http.Request) { return func(rw http.ResponseWriter, req *http.Request) { req.ParseForm() ctx := fosite.NewContext() - accessRequest, err := oauth2.NewAccessRequest(ctx, req, &fosite.DefaultSession{}) + accessRequest, err := provider.NewAccessRequest(ctx, req, &oauth2.JWTSession{}) if err != nil { t.Logf("Access request failed because %s.", err.Error()) t.Logf("Request: %s.", accessRequest) t.Logf("Stack: %v.", err.(stackTracer).StackTrace()) - oauth2.WriteAccessError(rw, accessRequest, err) + provider.WriteAccessError(rw, accessRequest, err) return } @@ -131,15 +132,15 @@ func tokenEndpointHandler(t *testing.T, oauth2 fosite.OAuth2Provider) func(rw ht accessRequest.GrantScope("fosite") } - response, err := oauth2.NewAccessResponse(ctx, req, accessRequest) + response, err := provider.NewAccessResponse(ctx, req, accessRequest) if err != nil { t.Logf("Access request failed because %s.", err.Error()) t.Logf("Request: %s.", accessRequest) t.Logf("Stack: %v.", err.(stackTracer).StackTrace()) - oauth2.WriteAccessError(rw, accessRequest, err) + provider.WriteAccessError(rw, accessRequest, err) return } - oauth2.WriteAccessResponse(rw, accessRequest, response) + provider.WriteAccessResponse(rw, accessRequest, response) } } diff --git a/integration/helper_setup_test.go b/integration/helper_setup_test.go index 56de09756..641b523f4 100644 --- a/integration/helper_setup_test.go +++ b/integration/helper_setup_test.go @@ -9,8 +9,10 @@ import ( "github.com/ory-am/fosite" "github.com/ory-am/fosite/handler/oauth2" "github.com/ory-am/fosite/handler/openid" + "github.com/ory-am/fosite/internal" "github.com/ory-am/fosite/storage" "github.com/ory-am/fosite/token/hmac" + "github.com/ory-am/fosite/token/jwt" goauth "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" ) @@ -79,6 +81,12 @@ var hmacStrategy = &oauth2.HMACSHAStrategy{ AuthorizeCodeLifespan: authCodeLifespan, } +var jwtStrategy = &oauth2.RS256JWTStrategy{ + RS256JWTStrategy: &jwt.RS256JWTStrategy{ + PrivateKey: internal.MustRSAKey(), + }, +} + func mockServer(t *testing.T, f fosite.OAuth2Provider, session fosite.Session) *httptest.Server { router := mux.NewRouter() router.HandleFunc("/auth", authEndpointHandler(t, f, session)) diff --git a/integration/introspect_token_test.go b/integration/introspect_token_test.go index e64113b55..581f224f0 100644 --- a/integration/introspect_token_test.go +++ b/integration/introspect_token_test.go @@ -14,15 +14,34 @@ import ( ) func TestIntrospectToken(t *testing.T) { - for _, strategy := range []oauth2.AccessTokenStrategy{ - hmacStrategy, + for _, c := range []struct { + description string + strategy oauth2.AccessTokenStrategy + factory compose.Factory + }{ + { + description: "HMAC strategy with OAuth2TokenIntrospectionFactory", + strategy: hmacStrategy, + factory: compose.OAuth2TokenIntrospectionFactory, + }, + { + description: "JWT strategy with OAuth2TokenIntrospectionFactory", + strategy: jwtStrategy, + factory: compose.OAuth2TokenIntrospectionFactory, + }, + { + description: "JWT strategy with OAuth2StatelessJWTIntrospectionFactory", + strategy: jwtStrategy, + factory: compose.OAuth2StatelessJWTIntrospectionFactory, + }, } { - runIntrospectTokenTest(t, strategy) + t.Logf("testing %v", c.description) + runIntrospectTokenTest(t, c.strategy, c.factory) } } -func runIntrospectTokenTest(t *testing.T, strategy oauth2.AccessTokenStrategy) { - f := compose.Compose(new(compose.Config), fositeStore, strategy, compose.OAuth2ClientCredentialsGrantFactory, compose.OAuth2TokenIntrospectionFactory) +func runIntrospectTokenTest(t *testing.T, strategy oauth2.AccessTokenStrategy, introspectionFactory compose.Factory) { + f := compose.Compose(new(compose.Config), fositeStore, strategy, compose.OAuth2ClientCredentialsGrantFactory, introspectionFactory) ts := mockServer(t, f, &fosite.DefaultSession{}) defer ts.Close() @@ -74,7 +93,11 @@ func runIntrospectTokenTest(t *testing.T, strategy oauth2.AccessTokenStrategy) { }, } { res := struct { - Active bool `json:"active"` + Active bool `json:"active"` + ClientId string `json:"client_id"` + Scope string `json:"scope"` + ExpiresAt float64 `json:"exp"` + IssuedAt float64 `json:"iat"` }{} s := gorequest.New() s = s.Post(ts.URL + "/introspect"). @@ -86,5 +109,11 @@ func runIntrospectTokenTest(t *testing.T, strategy oauth2.AccessTokenStrategy) { t.Logf("Got answer: %s", bytes) assert.Len(t, errs, 0) assert.Equal(t, c.isActive, res.Active) + if c.isActive { + assert.Equal(t, "fosite", res.Scope) + assert.True(t, res.ExpiresAt > 0) + assert.True(t, res.IssuedAt > 0) + assert.True(t, res.IssuedAt < res.ExpiresAt) + } } } diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index 3f4922be1..e0518f6c7 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -16,6 +16,7 @@ type JWTClaims struct { IssuedAt time.Time NotBefore time.Time ExpiresAt time.Time + Scope []string Extra map[string]interface{} } @@ -31,12 +32,84 @@ func (c *JWTClaims) ToMap() map[string]interface{} { ret["sub"] = c.Subject ret["iss"] = c.Issuer ret["aud"] = c.Audience - ret["iat"] = float64(c.IssuedAt.Unix()) // jwt-go does not support int64 as datatype - ret["nbf"] = float64(c.NotBefore.Unix()) // jwt-go does not support int64 as datatype + + if !c.IssuedAt.IsZero() { + ret["iat"] = float64(c.IssuedAt.Unix()) // jwt-go does not support int64 as datatype + } + + if !c.NotBefore.IsZero() { + ret["nbf"] = float64(c.NotBefore.Unix()) // jwt-go does not support int64 as datatype + } + ret["exp"] = float64(c.ExpiresAt.Unix()) // jwt-go does not support int64 as datatype + + if c.Scope != nil { + ret["scp"] = c.Scope + } + return ret } +// FromMap will set the claims based on a mapping +func (c *JWTClaims) FromMap(m map[string]interface{}) { + c.Extra = make(map[string]interface{}) + for k, v := range m { + switch k { + case "jti": + if s, ok := v.(string); ok { + c.JTI = s + } + case "sub": + if s, ok := v.(string); ok { + c.Subject = s + } + case "iss": + if s, ok := v.(string); ok { + c.Issuer = s + } + case "aud": + if s, ok := v.(string); ok { + c.Audience = s + } + case "iat": + switch v.(type) { + case float64: + c.IssuedAt = time.Unix(int64(v.(float64)), 0) + case int64: + c.IssuedAt = time.Unix(v.(int64), 0) + } + case "nbf": + switch v.(type) { + case float64: + c.NotBefore = time.Unix(int64(v.(float64)), 0) + case int64: + c.NotBefore = time.Unix(v.(int64), 0) + } + case "exp": + switch v.(type) { + case float64: + c.ExpiresAt = time.Unix(int64(v.(float64)), 0) + case int64: + c.ExpiresAt = time.Unix(v.(int64), 0) + } + case "scp": + switch v.(type) { + case []string: + c.Scope = v.([]string) + case []interface{}: + c.Scope = make([]string, len(v.([]interface{}))) + for i, vi := range v.([]interface{}) { + if s, ok := vi.(string); ok { + c.Scope[i] = s + } + } + } + default: + c.Extra[k] = v + } + } +} + // Add will add a key-value pair to the extra field func (c *JWTClaims) Add(key string, value interface{}) { if c.Extra == nil { @@ -54,3 +127,8 @@ func (c JWTClaims) Get(key string) interface{} { func (c JWTClaims) ToMapClaims() jwt.MapClaims { return c.ToMap() } + +// FromMapClaims will populate claims from a jwt-go MapClaims representaion +func (c *JWTClaims) FromMapClaims(mc jwt.MapClaims) { + c.FromMap(mc) +} diff --git a/token/jwt/claims_jwt_test.go b/token/jwt/claims_jwt_test.go index 1dc3259fb..c9ba09df0 100644 --- a/token/jwt/claims_jwt_test.go +++ b/token/jwt/claims_jwt_test.go @@ -16,12 +16,26 @@ var jwtClaims = &JWTClaims{ Audience: "tests", ExpiresAt: time.Now().Add(time.Hour).Round(time.Second), JTI: "abcdef", + Scope: []string{"email", "offline"}, Extra: map[string]interface{}{ "foo": "bar", "baz": "bar", }, } +var jwtClaimsMap = map[string]interface{}{ + "sub": jwtClaims.Subject, + "iat": float64(jwtClaims.IssuedAt.Unix()), + "iss": jwtClaims.Issuer, + "nbf": float64(jwtClaims.NotBefore.Unix()), + "aud": jwtClaims.Audience, + "exp": float64(jwtClaims.ExpiresAt.Unix()), + "jti": jwtClaims.JTI, + "scp": []string{"email", "offline"}, + "foo": jwtClaims.Extra["foo"], + "baz": jwtClaims.Extra["baz"], +} + func TestClaimAddGetString(t *testing.T) { jwtClaims.Add("foo", "bar") assert.Equal(t, "bar", jwtClaims.Get("foo")) @@ -45,15 +59,11 @@ func TestAssert(t *testing.T) { } func TestClaimsToMap(t *testing.T) { - assert.Equal(t, map[string]interface{}{ - "sub": jwtClaims.Subject, - "iat": float64(jwtClaims.IssuedAt.Unix()), - "iss": jwtClaims.Issuer, - "nbf": float64(jwtClaims.NotBefore.Unix()), - "aud": jwtClaims.Audience, - "exp": float64(jwtClaims.ExpiresAt.Unix()), - "jti": jwtClaims.JTI, - "foo": jwtClaims.Extra["foo"], - "baz": jwtClaims.Extra["baz"], - }, jwtClaims.ToMap()) + assert.Equal(t, jwtClaimsMap, jwtClaims.ToMap()) +} + +func TestClaimsFromMap(t *testing.T) { + var claims JWTClaims + claims.FromMap(jwtClaimsMap) + assert.Equal(t, jwtClaims, &claims) } diff --git a/token/jwt/jwt.go b/token/jwt/jwt.go index 786f7b941..319b9f5aa 100644 --- a/token/jwt/jwt.go +++ b/token/jwt/jwt.go @@ -59,7 +59,7 @@ func (j *RS256JWTStrategy) Decode(token string) (*jwt.Token, error) { }) if err != nil { - return nil, errors.Errorf("Couldn't parse token: %v", err) + return nil, errors.Wrap(err, "Couldn't parse token") } else if !parsedToken.Valid { return nil, errors.Errorf("Token is invalid") }