Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Scope can now be space delimited in access tokens #482

Merged
merged 5 commits into from
Oct 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions compose/compose_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,14 @@ func NewOAuth2JWTECDSAStrategy(key *ecdsa.PrivateKey, strategy *oauth2.HMACSHASt
}
}

// Deprecated: Use NewOAuth2JWTStrategy(key, strategy).WithIssuer(issuer) instead.
func NewOAuth2JWTStrategyWithIssuer(key *rsa.PrivateKey, strategy *oauth2.HMACSHAStrategy, issuer string) *oauth2.DefaultJWTStrategy {
mitar marked this conversation as resolved.
Show resolved Hide resolved
return &oauth2.DefaultJWTStrategy{
JWTStrategy: &jwt.RS256JWTStrategy{
PrivateKey: key,
},
HMACSHAStrategy: strategy,
Issuer: issuer,
}
return NewOAuth2JWTStrategy(key, strategy).WithIssuer(issuer)
}

// Deprecated: Use NewOAuth2JWTECDSAStrategy(key, strategy).WithIssuer(issuer) instead.
func NewOAuth2JWTECDSAStrategyWithIssuer(key *ecdsa.PrivateKey, strategy *oauth2.HMACSHAStrategy, issuer string) *oauth2.DefaultJWTStrategy {
return &oauth2.DefaultJWTStrategy{
JWTStrategy: &jwt.ES256JWTStrategy{
PrivateKey: key,
},
HMACSHAStrategy: strategy,
Issuer: issuer,
}
return NewOAuth2JWTECDSAStrategy(key, strategy).WithIssuer(issuer)
}

func NewOpenIDConnectStrategy(config *Config, key *rsa.PrivateKey) *openid.DefaultStrategy {
Expand Down
18 changes: 17 additions & 1 deletion handler/oauth2/strategy_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ type DefaultJWTStrategy struct {
jwt.JWTStrategy
HMACSHAStrategy *HMACSHAStrategy
Issuer string
ScopeField jwt.JWTScopeFieldEnum
}

func (h *DefaultJWTStrategy) WithIssuer(issuer string) *DefaultJWTStrategy {
h.Issuer = issuer
return h
}

func (h *DefaultJWTStrategy) WithScopeField(scopeField jwt.JWTScopeFieldEnum) *DefaultJWTStrategy {
h.ScopeField = scopeField
return h
}

func (h DefaultJWTStrategy) signature(token string) string {
Expand Down Expand Up @@ -68,7 +79,9 @@ func (h *DefaultJWTStrategy) ValidateJWT(ctx context.Context, tokenType fosite.T
return nil, err
}

claims := jwt.JWTClaims{}
claims := jwt.JWTClaims{
ScopeField: h.ScopeField,
}
claims.FromMapClaims(t.Claims.(jwtx.MapClaims))

requester = &fosite.Request{
Expand Down Expand Up @@ -170,6 +183,9 @@ func (h *DefaultJWTStrategy) generate(ctx context.Context, tokenType fosite.Toke
WithDefaults(
time.Now().UTC(),
h.Issuer,
).
WithScopeField(
h.ScopeField,
)

return h.JWTStrategy.Generate(ctx, claims.ToMapClaims(), jwtSession.GetJWTHeader())
Expand Down
119 changes: 80 additions & 39 deletions handler/oauth2/strategy_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
package oauth2

import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"testing"
Expand All @@ -45,7 +47,7 @@ var j = &DefaultJWTStrategy{
// left empty to ensure it is pulled from the session's ExpiresAt map for
// the given fosite.TokenType.
var jwtValidCase = func(tokenType fosite.TokenType) *fosite.Request {
return &fosite.Request{
r := &fosite.Request{
Client: &fosite.DefaultClient{
Secret: []byte("foobarfoobarfoobarfoobar"),
},
Expand All @@ -66,10 +68,14 @@ var jwtValidCase = func(tokenType fosite.TokenType) *fosite.Request {
},
},
}
r.SetRequestedScopes([]string{"email", "offline"})
r.GrantScope("email")
r.GrantScope("offline")
return r
}

var jwtValidCaseWithZeroRefreshExpiry = func(tokenType fosite.TokenType) *fosite.Request {
return &fosite.Request{
r := &fosite.Request{
Client: &fosite.DefaultClient{
Secret: []byte("foobarfoobarfoobarfoobar"),
},
Expand All @@ -91,10 +97,14 @@ var jwtValidCaseWithZeroRefreshExpiry = func(tokenType fosite.TokenType) *fosite
},
},
}
r.SetRequestedScopes([]string{"email", "offline"})
r.GrantScope("email")
r.GrantScope("offline")
return r
}

var jwtValidCaseWithRefreshExpiry = func(tokenType fosite.TokenType) *fosite.Request {
return &fosite.Request{
r := &fosite.Request{
Client: &fosite.DefaultClient{
Secret: []byte("foobarfoobarfoobarfoobar"),
},
Expand All @@ -116,13 +126,17 @@ var jwtValidCaseWithRefreshExpiry = func(tokenType fosite.TokenType) *fosite.Req
},
},
}
r.SetRequestedScopes([]string{"email", "offline"})
r.GrantScope("email")
r.GrantScope("offline")
return r
}

// returns an expired JWT type. The JWTClaims.ExpiresAt time is intentionally
// left empty to ensure it is pulled from the session's ExpiresAt map for
// the given fosite.TokenType.
var jwtExpiredCase = func(tokenType fosite.TokenType) *fosite.Request {
return &fosite.Request{
r := &fosite.Request{
Client: &fosite.DefaultClient{
Secret: []byte("foobarfoobarfoobarfoobar"),
},
Expand All @@ -144,46 +158,73 @@ var jwtExpiredCase = func(tokenType fosite.TokenType) *fosite.Request {
},
},
}
r.SetRequestedScopes([]string{"email", "offline"})
r.GrantScope("email")
r.GrantScope("offline")
return r
}

func TestAccessToken(t *testing.T) {
for k, c := range []struct {
r *fosite.Request
pass bool
}{
{
r: jwtValidCase(fosite.AccessToken),
pass: true,
},
{
r: jwtExpiredCase(fosite.AccessToken),
pass: false,
},
{
r: jwtValidCaseWithZeroRefreshExpiry(fosite.AccessToken),
pass: true,
},
{
r: jwtValidCaseWithRefreshExpiry(fosite.AccessToken),
pass: true,
},
for s, scopeField := range []jwt.JWTScopeFieldEnum{
jwt.JWTScopeFieldList,
jwt.JWTScopeFieldString,
jwt.JWTScopeFieldBoth,
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
token, signature, err := j.GenerateAccessToken(nil, c.r)
assert.NoError(t, err)
for k, c := range []struct {
r *fosite.Request
pass bool
}{
{
r: jwtValidCase(fosite.AccessToken),
pass: true,
},
{
r: jwtExpiredCase(fosite.AccessToken),
pass: false,
},
{
r: jwtValidCaseWithZeroRefreshExpiry(fosite.AccessToken),
pass: true,
},
{
r: jwtValidCaseWithRefreshExpiry(fosite.AccessToken),
pass: true,
},
} {
t.Run(fmt.Sprintf("case=%d/%d", s, k), func(t *testing.T) {
jWithField := j.WithScopeField(scopeField)
token, signature, err := jWithField.GenerateAccessToken(nil, c.r)
assert.NoError(t, err)

parts := strings.Split(token, ".")
require.Len(t, parts, 3, "%s - %v", token, parts)
assert.Equal(t, parts[2], signature)
parts := strings.Split(token, ".")
require.Len(t, parts, 3, "%s - %v", token, parts)
assert.Equal(t, parts[2], signature)

validate := j.signature(token)
err = j.ValidateAccessToken(nil, c.r, token)
if c.pass {
assert.NoError(t, err)
assert.Equal(t, signature, validate)
} else {
assert.Error(t, err)
}
})
rawPayload, err := base64.RawURLEncoding.DecodeString(parts[1])
require.NoError(t, err)
var payload map[string]interface{}
err = json.Unmarshal(rawPayload, &payload)
require.NoError(t, err)
if scopeField == jwt.JWTScopeFieldList || scopeField == jwt.JWTScopeFieldBoth {
scope, ok := payload["scp"]
require.True(t, ok)
assert.Equal(t, []interface{}{"email", "offline"}, scope)
}
if scopeField == jwt.JWTScopeFieldString || scopeField == jwt.JWTScopeFieldBoth {
scope, ok := payload["scope"]
require.True(t, ok)
assert.Equal(t, "email offline", scope)
}

validate := jWithField.signature(token)
err = jWithField.ValidateAccessToken(nil, c.r, token)
if c.pass {
assert.NoError(t, err)
assert.Equal(t, signature, validate)
} else {
assert.Error(t, err)
}
})
}
}
}
65 changes: 55 additions & 10 deletions token/jwt/claims_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,23 @@
package jwt

import (
"strings"
"time"

jwt "github.com/dgrijalva/jwt-go"
"github.com/pborman/uuid"
)

// Enum for different types of scope encoding.
type JWTScopeFieldEnum int

const (
JWTScopeFieldUnset JWTScopeFieldEnum = iota
JWTScopeFieldList
JWTScopeFieldString
JWTScopeFieldBoth
)

type JWTClaimsDefaults struct {
ExpiresAt time.Time
IssuedAt time.Time
Expand All @@ -43,21 +54,25 @@ type JWTClaimsContainer interface {
// values are already set in the claims, they will not be updated.
WithDefaults(iat time.Time, issuer string) JWTClaimsContainer

// WithScopeField configures how a scope field should be represented in JWT.
WithScopeField(scopeField JWTScopeFieldEnum) JWTClaimsContainer

// ToMapClaims returns the claims as a github.com/dgrijalva/jwt-go.MapClaims type.
ToMapClaims() jwt.MapClaims
}

// JWTClaims represent a token's claims.
type JWTClaims struct {
Subject string
Issuer string
Audience []string
JTI string
IssuedAt time.Time
NotBefore time.Time
ExpiresAt time.Time
Scope []string
Extra map[string]interface{}
Subject string
Issuer string
Audience []string
JTI string
IssuedAt time.Time
NotBefore time.Time
ExpiresAt time.Time
Scope []string
Extra map[string]interface{}
ScopeField JWTScopeFieldEnum
}

func (c *JWTClaims) With(expiry time.Time, scope, audience []string) JWTClaimsContainer {
Expand All @@ -78,6 +93,11 @@ func (c *JWTClaims) WithDefaults(iat time.Time, issuer string) JWTClaimsContaine
return c
}

func (c *JWTClaims) WithScopeField(scopeField JWTScopeFieldEnum) JWTClaimsContainer {
c.ScopeField = scopeField
return c
}

// ToMap will transform the headers to a map structure
func (c *JWTClaims) ToMap() map[string]interface{} {
var ret = Copy(c.Extra)
Expand All @@ -102,7 +122,13 @@ func (c *JWTClaims) ToMap() map[string]interface{} {
ret["exp"] = float64(c.ExpiresAt.Unix()) // jwt-go does not support int64 as datatype

if c.Scope != nil {
ret["scp"] = c.Scope
// ScopeField default (when value is JWTScopeFieldUnset) is the list for backwards compatibility with old versions of fosite.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aeneasr Are you sure we want to keep backwards compatibility here? Existing behavior is not by the spec and it can be surprised that JWT tokens looks like they do not have scopes (because you are looking at the standard field). Should we change default here to be by the spec by default?

if c.ScopeField == JWTScopeFieldUnset || c.ScopeField == JWTScopeFieldList || c.ScopeField == JWTScopeFieldBoth {
ret["scp"] = c.Scope
}
if c.ScopeField == JWTScopeFieldString || c.ScopeField == JWTScopeFieldBoth {
ret["scope"] = strings.Join(c.Scope, " ")
}
}

return ret
Expand Down Expand Up @@ -156,13 +182,32 @@ func (c *JWTClaims) FromMap(m map[string]interface{}) {
switch v.(type) {
case []string:
c.Scope = v.([]string)
if c.ScopeField == JWTScopeFieldString {
c.ScopeField = JWTScopeFieldBoth
} else if c.ScopeField == JWTScopeFieldUnset {
c.ScopeField = JWTScopeFieldList
}
case []interface{}:
c.Scope = make([]string, len(v.([]interface{})))
for i, vi := range v.([]interface{}) {
if s, ok := vi.(string); ok {
c.Scope[i] = s
}
}
if c.ScopeField == JWTScopeFieldString {
c.ScopeField = JWTScopeFieldBoth
} else if c.ScopeField == JWTScopeFieldUnset {
c.ScopeField = JWTScopeFieldList
}
}
case "scope":
if s, ok := v.(string); ok {
c.Scope = strings.Split(s, " ")
if c.ScopeField == JWTScopeFieldList {
c.ScopeField = JWTScopeFieldBoth
} else if c.ScopeField == JWTScopeFieldUnset {
c.ScopeField = JWTScopeFieldString
}
}
default:
c.Extra[k] = v
Expand Down
24 changes: 24 additions & 0 deletions token/jwt/claims_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var jwtClaims = &JWTClaims{
"foo": "bar",
"baz": "bar",
},
ScopeField: JWTScopeFieldList,
}

var jwtClaimsMap = map[string]interface{}{
Expand Down Expand Up @@ -89,3 +90,26 @@ func TestClaimsFromMap(t *testing.T) {
claims.FromMap(jwtClaimsMap)
assert.Equal(t, jwtClaims, &claims)
}

func TestScopeFieldString(t *testing.T) {
jwtClaimsWithString := jwtClaims.WithScopeField(JWTScopeFieldString)
// Making a copy of jwtClaimsMap.
jwtClaimsMapWithString := jwtClaims.ToMap()
delete(jwtClaimsMapWithString, "scp")
jwtClaimsMapWithString["scope"] = "email offline"
assert.Equal(t, jwtClaimsMapWithString, map[string]interface{}(jwtClaimsWithString.ToMapClaims()))
var claims JWTClaims
claims.FromMap(jwtClaimsMapWithString)
assert.Equal(t, jwtClaimsWithString, &claims)
}

func TestScopeFieldBoth(t *testing.T) {
jwtClaimsWithBoth := jwtClaims.WithScopeField(JWTScopeFieldBoth)
// Making a copy of jwtClaimsMap.
jwtClaimsMapWithBoth := jwtClaims.ToMap()
jwtClaimsMapWithBoth["scope"] = "email offline"
assert.Equal(t, jwtClaimsMapWithBoth, map[string]interface{}(jwtClaimsWithBoth.ToMapClaims()))
var claims JWTClaims
claims.FromMap(jwtClaimsMapWithBoth)
assert.Equal(t, jwtClaimsWithBoth, &claims)
}