Skip to content

Commit

Permalink
Allow for user provided context.Context in parsing (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
MicahParks authored Mar 25, 2024
1 parent accaf1a commit a6bfa43
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 47 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/MicahParks/keyfunc/v3
go 1.21

require (
github.com/MicahParks/jwkset v0.5.15
github.com/MicahParks/jwkset v0.5.17
github.com/golang-jwt/jwt/v5 v5.2.0
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/MicahParks/jwkset v0.5.15 h1:ACJY045Zuvo2TVWikeFLnKTIsEDQQHUHrNYiMW+gj24=
github.com/MicahParks/jwkset v0.5.15/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY=
github.com/MicahParks/jwkset v0.5.17 h1:DrcwyKwSP5adD0G2XJTvDulnWXjD6gbjROMgMXDbkKA=
github.com/MicahParks/jwkset v0.5.17/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
Expand Down
94 changes: 50 additions & 44 deletions keyfunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,58 +112,64 @@ func NewJWKSetJSON(raw json.RawMessage) (Keyfunc, error) {
return New(options)
}

func (k keyfunc) Keyfunc(token *jwt.Token) (any, error) {
kidInter, ok := token.Header[jwkset.HeaderKID]
if !ok {
return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKeyfunc)
}
kid, ok := kidInter.(string)
if !ok {
return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKeyfunc)
}
algInter, ok := token.Header["alg"]
if !ok {
return nil, fmt.Errorf("%w: could not find alg in JWT header", ErrKeyfunc)
}
alg, ok := algInter.(string)
if !ok {
// For test coverage purposes, this should be impossible to reach because the JWT package rejects a token
// without an alg parameter in the header before calling jwt.Keyfunc.
return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrKeyfunc)
}
func (k keyfunc) KeyfuncCtx(ctx context.Context) func(token *jwt.Token) (any, error) {
return func(token *jwt.Token) (any, error) {
kidInter, ok := token.Header[jwkset.HeaderKID]
if !ok {
return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKeyfunc)
}
kid, ok := kidInter.(string)
if !ok {
return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKeyfunc)
}
algInter, ok := token.Header["alg"]
if !ok {
return nil, fmt.Errorf("%w: could not find alg in JWT header", ErrKeyfunc)
}
alg, ok := algInter.(string)
if !ok {
// For test coverage purposes, this should be impossible to reach because the JWT package rejects a token
// without an alg parameter in the header before calling jwt.Keyfunc.
return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrKeyfunc)
}

jwk, err := k.storage.KeyRead(k.ctx, kid)
if err != nil {
return nil, fmt.Errorf("%w: could not read JWK from storage", errors.Join(err, ErrKeyfunc))
}
jwk, err := k.storage.KeyRead(ctx, kid)
if err != nil {
return nil, fmt.Errorf("%w: could not read JWK from storage", errors.Join(err, ErrKeyfunc))
}

if a := jwk.Marshal().ALG.String(); a != "" && a != alg {
return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrKeyfunc, a, alg)
}
if len(k.useWhitelist) > 0 {
found := false
for _, u := range k.useWhitelist {
if jwk.Marshal().USE == u {
found = true
break
if a := jwk.Marshal().ALG.String(); a != "" && a != alg {
return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrKeyfunc, a, alg)
}
if len(k.useWhitelist) > 0 {
found := false
for _, u := range k.useWhitelist {
if jwk.Marshal().USE == u {
found = true
break
}
}
if !found {
return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not in whitelist`, ErrKeyfunc, jwk.Marshal().USE)
}
}
if !found {
return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not in whitelist`, ErrKeyfunc, jwk.Marshal().USE)

type publicKeyer interface {
Public() crypto.PublicKey
}
}

type publicKeyer interface {
Public() crypto.PublicKey
}
key := jwk.Key()
pk, ok := key.(publicKeyer)
if ok {
key = pk.Public()
}

key := jwk.Key()
pk, ok := key.(publicKeyer)
if ok {
key = pk.Public()
return key, nil
}

return key, nil
}
func (k keyfunc) Keyfunc(token *jwt.Token) (any, error) {
keyF := k.KeyfuncCtx(k.ctx)
return keyF(token)
}
func (k keyfunc) Storage() jwkset.Storage {
return k.storage
Expand Down

0 comments on commit a6bfa43

Please sign in to comment.