Skip to content

Commit

Permalink
breaking change: HeaderValidator: returns an extra output argument wh…
Browse files Browse the repository at this point in the history
…ich can optionally (if not nil) set the decryption method dynamically based on the kid
  • Loading branch information
kataras committed Mar 27, 2022
1 parent 4ee796d commit d03e03a
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 26 deletions.
14 changes: 7 additions & 7 deletions _examples/custom-header/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,30 +68,30 @@ type Header struct {
Alg string `json:"alg"`
}

func validateHeader(alg string, headerDecoded []byte) (jwt.PublicKey, error) {
func validateHeader(alg string, headerDecoded []byte) (jwt.Alg, jwt.PublicKey, jwt.InjectFunc, error) {
var h Header
err := jwt.Unmarshal(headerDecoded, &h)
if err != nil {
return nil, err
return nil, nil, nil, err
}

if h.Alg != alg {
return nil, jwt.ErrTokenAlg
return nil, nil, nil, jwt.ErrTokenAlg
}

if h.Kid == "" {
return nil, fmt.Errorf("kid is empty")
return nil, nil, nil, fmt.Errorf("kid is empty")
}

key, ok := keys[h.Kid]
if !ok {
return nil, fmt.Errorf("unknown kid")
return nil, nil, nil, fmt.Errorf("unknown kid")
}

publicKey, err := jwt.ParsePublicKeyRSA(key)
if err != nil {
return nil, jwt.ErrTokenAlg
return nil, nil, nil, jwt.ErrTokenAlg
}

return publicKey, nil
return nil, publicKey, nil, nil
}
39 changes: 30 additions & 9 deletions kid_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type (
Public PublicKey
Private PrivateKey
MaxAge time.Duration // optional.
Encrypt InjectFunc // optional.
Decrypt InjectFunc // optional.
}

// Keys is a map which holds the key id and a key pair.
Expand Down Expand Up @@ -75,11 +77,20 @@ type (
Alg string `json:"alg" yaml:"Alg" toml:"Alg" ini:"alg"`
Private string `json:"private" yaml:"Private" toml:"Private" ini:"private"`
Public string `json:"public" yaml:"Public" toml:"Public" ini:"public"`
// Token expiration. Optional.
// MaxAge sets the token expiration. It is optional.
// If greater than zero then the MaxAge token validation
// will be appended to the "VerifyToken" and the token is invalid
// after expiration of its sign time.
MaxAge time.Duration `json:"max_age" yaml:"MaxAge" toml:"MaxAge" ini:"max_age"`

// EncryptionKey enables encryption on the generated token. It is optional.
// Encryption using the Galois Counter mode of operation with
// AES cipher symmetric-key cryptographic.
//
// The value should be the AES key,
// either 16, 24, or 32 bytes to select
// AES-128, AES-192, or AES-256.
EncryptionKey string `json:"encryption_key" yaml:"EncryptionKey" toml:"EncryptionKey" ini:"encryption_key"`
}
)

Expand Down Expand Up @@ -131,6 +142,16 @@ func (c KeysConfiguration) Load() (Keys, error) {
p.Public = entry.Public
}

if entry.EncryptionKey != "" {
encrypt, decrypt, err := GCM([]byte(entry.EncryptionKey), nil)
if err != nil {
return nil, fmt.Errorf("jwt: load keys: build encryption: %w", err)
}

p.Encrypt = encrypt
p.Decrypt = decrypt
}

parsedKeys[entry.ID] = p
}

Expand All @@ -155,33 +176,33 @@ func (keys Keys) Register(alg Alg, kid string, pubKey PublicKey, privKey Private

// ValidateHeader validates the given json header value (base64 decoded) based on the "keys".
// Keys structure completes the `HeaderValidator` interface.
func (keys Keys) ValidateHeader(alg string, headerDecoded []byte) (Alg, PublicKey, error) {
func (keys Keys) ValidateHeader(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error) {
var h HeaderWithKid

err := Unmarshal(headerDecoded, &h)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

if h.Kid == "" {
return nil, nil, ErrEmptyKid
return nil, nil, nil, ErrEmptyKid
}

key, ok := keys.Get(h.Kid)
if !ok {
return nil, nil, ErrUnknownKid
return nil, nil, nil, ErrUnknownKid
}

if h.Alg != key.Alg.Name() {
return nil, nil, ErrTokenAlg
return nil, nil, nil, ErrTokenAlg
}

// If for some reason a specific alg was given by the caller then check that as well.
if alg != "" && alg != h.Alg {
return nil, nil, ErrTokenAlg
return nil, nil, nil, ErrTokenAlg
}

return key.Alg, key.Public, nil
return key.Alg, key.Public, key.Decrypt, nil
}

// SignToken signs the "claims" using the given "alg" based a specific key.
Expand All @@ -195,7 +216,7 @@ func (keys Keys) SignToken(kid string, claims interface{}, opts ...SignOption) (
opts = append([]SignOption{MaxAge(k.MaxAge)}, opts...)
}

return SignWithHeader(k.Alg, k.Private, claims, HeaderWithKid{
return SignEncryptedWithHeader(k.Alg, k.Private, k.Encrypt, claims, HeaderWithKid{
Kid: kid,
Alg: k.Alg.Name(),
}, opts...)
Expand Down
28 changes: 19 additions & 9 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func decodeToken(alg Alg, key PublicKey, token []byte, compareHeaderFunc HeaderV
algName = alg.Name()
}

dynamicAlg, pubKey, err := compareHeaderFunc(algName, headerDecoded)
dynamicAlg, pubKey, decrypt, err := compareHeaderFunc(algName, headerDecoded)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -113,6 +113,14 @@ func decodeToken(alg Alg, key PublicKey, token []byte, compareHeaderFunc HeaderV
if err != nil {
return nil, nil, nil, err
}

if decrypt != nil {
payload, err = decrypt(payload)
if err != nil {
return nil, nil, nil, err
}
}

return headerDecoded, payload, signatureDecoded, nil
}

Expand Down Expand Up @@ -196,22 +204,24 @@ func createHeaderWithoutTyp(alg string) []byte {
// If the "alg" is empty then this function should return a non-nil algorithm
// based on the token contents.
// It should return a nil PublicKey and a non-nil error on validation failure.
// The out InjectFunc is optional. If it's not nil then decryption of the payload
// using GCM (AES key) is performed before verification.
// On success, if public key is not nil then it overrides the VerifyXXX method's one.
type HeaderValidator func(alg string, headerDecoded []byte) (Alg, PublicKey, error)
type HeaderValidator func(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error)

// Note that this check is fully hard coded for known
// algorithms and it is fully hard coded in terms of
// its serialized format.
func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, error) {
func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error) {
if n := len(headerDecoded); n < 25 /* 28 but allow custom short algs*/ {
if n == 15 { // header without "typ": "JWT".
expectedHeader := createHeaderWithoutTyp(alg)
if bytes.Equal(expectedHeader, headerDecoded) {
return nil, nil, nil
return nil, nil, nil, nil
}
}

return nil, nil, ErrTokenAlg
return nil, nil, nil, ErrTokenAlg
}

// Fast check if the order is reversed.
Expand All @@ -221,18 +231,18 @@ func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, error) {
if headerDecoded[2] == 't' {
expectedHeader := createHeaderReversed(alg)
if !bytes.Equal(expectedHeader, headerDecoded) {
return nil, nil, ErrTokenAlg
return nil, nil, nil, ErrTokenAlg
}

return nil, nil, nil
return nil, nil, nil, nil
}

expectedHeader := createHeaderRaw(alg)
if !bytes.Equal(expectedHeader, headerDecoded) {
return nil, nil, ErrTokenAlg
return nil, nil, nil, ErrTokenAlg
}

return nil, nil, nil
return nil, nil, nil, nil
}

func createSignature(alg Alg, key PrivateKey, headerAndPayload []byte) ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestCompareHeader(t *testing.T) {
}

for i, tt := range tests {
_, _, err := compareHeader(tt.alg, []byte(tt.header))
_, _, _, err := compareHeader(tt.alg, []byte(tt.header))
if tt.ok && err != nil {
t.Fatalf("[%d] expected to pass but got error: %v", i, err)
}
Expand Down

0 comments on commit d03e03a

Please sign in to comment.