Skip to content

Commit

Permalink
feat: "kid" support in JWT header and multiple keys.
Browse files Browse the repository at this point in the history
Ticket: MEN-6804
Changelog: title
Signed-off-by: Peter Grzybowski <[email protected]>
  • Loading branch information
merlin-northern committed Nov 17, 2023
1 parent e3ec595 commit fcfc84f
Show file tree
Hide file tree
Showing 15 changed files with 283 additions and 43 deletions.
2 changes: 1 addition & 1 deletion api/http/api_useradm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ func makeMockApiHandler(t *testing.T, uadm useradm.App, db store.DataStore) http
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
assert.NoError(t, err)

jwth := jwt.NewJWTHandlerRS256(key)
jwth := jwt.NewJWTHandlerRS256(key, 0)

// API handler
handlers := NewUserAdmApiHandlers(uadm, db, jwth, Config{})
Expand Down
2 changes: 1 addition & 1 deletion authz/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (mw *AuthzMiddleware) MiddlewareFunc(h rest.HandlerFunc) rest.HandlerFunc {
l := log.FromContext(r.Context())

//get token, no token header = http 401
tokstr, err := ExtractToken(r.Request)
tokstr, err := ExtractToken(r.Request) // here
if err != nil {
rest_utils.RestErrWithLog(w, r, l, err, http.StatusUnauthorized)
return
Expand Down
2 changes: 1 addition & 1 deletion authz/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ func TestAuthzMiddleware(t *testing.T) {

//finish setting up the middleware
privkey := loadPrivKey("../crypto/private.pem", t)
jwth := jwt.NewJWTHandlerRS256(privkey)
jwth := jwt.NewJWTHandlerRS256(privkey, 0)
mw := AuthzMiddleware{
Authz: a,
ResFunc: resfunc,
Expand Down
45 changes: 45 additions & 0 deletions common/keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2023 Northern.tech AS
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
"path/filepath"
"regexp"
"strconv"

"github.com/pkg/errors"
)

const KeyIdZero = 0

var (
ErrKeyIdNotFound = errors.New("cant locate key by key id")
ErrKeyIdCollision = errors.New("key id already loaded")
)

func KeyIdFromPath(privateKeyPath string, privateKeyFilenamePattern string) (keyId int) {
fileName := filepath.Base(privateKeyPath)
r, _ := regexp.Compile(privateKeyFilenamePattern)
b := []byte(fileName)
indices := r.FindAllSubmatchIndex(b, -1)
keyId = KeyIdZero
if len(indices) > 0 && len(indices[0]) > 3 {
k, err := strconv.Atoi(string(b[indices[0][2]:indices[0][3]]))
if err == nil {
keyId = k
}
}
return keyId
}
8 changes: 6 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ const (
SettingMiddleware = "middleware"
SettingMiddlewareDefault = "prod"

SettingServerPrivKeyPath = "server_priv_key_path"
SettingServerPrivKeyPathDefault = "/etc/useradm/rsa/private.pem"
SettingServerPrivKeyPath = "server_priv_key_path"
SettingServerPrivKeyPathDefault = "/etc/useradm/rsa/private.pem"
SettingServerPrivKeyFileNamePattern = "server_priv_key_filename_pattern"
SettingServerPrivKeyFileNamePatternDefault = "private.id.([0-9]*).pem"

SettingServerFallbackPrivKeyPath = "server_fallback_priv_key_path"
SettingServerFallbackPrivKeyPathDefault = ""
Expand Down Expand Up @@ -73,6 +75,8 @@ var (
{Key: SettingListen, Value: SettingListenDefault},
{Key: SettingMiddleware, Value: SettingMiddlewareDefault},
{Key: SettingServerPrivKeyPath, Value: SettingServerPrivKeyPathDefault},
{Key: SettingServerPrivKeyFileNamePattern,
Value: SettingServerPrivKeyFileNamePatternDefault},
{Key: SettingServerFallbackPrivKeyPath, Value: SettingServerFallbackPrivKeyPathDefault},
{Key: SettingJWTIssuer, Value: SettingJWTIssuerDefault},
{Key: SettingJWTExpirationTimeout, Value: SettingJWTExpirationTimeoutDefault},
Expand Down
46 changes: 42 additions & 4 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"os"

"github.com/pkg/errors"

"github.com/mendersoftware/useradm/common"
)

var (
Expand All @@ -43,9 +45,10 @@ type Handler interface {
// ErrTokenExpired when the token is valid but expired
// ErrTokenInvalid when the token is invalid (malformed, missing required claims, etc.)
FromJWT(string) (*Token, error)
AddKey(privateKey interface{}, keyId int) (err error)
}

func NewJWTHandler(privateKeyPath string) (Handler, error) {
func NewJWTHandler(privateKeyPath string, privateKeyFilenamePattern string) (Handler, error) {
priv, err := os.ReadFile(privateKeyPath)
block, _ := pem.Decode(priv)
if block == nil {
Expand All @@ -57,18 +60,53 @@ func NewJWTHandler(privateKeyPath string) (Handler, error) {
if err != nil {
return nil, errors.Wrap(err, "failed to read rsa private key")
}
return NewJWTHandlerRS256(privKey), nil
return NewJWTHandlerRS256(
privKey,
common.KeyIdFromPath(privateKeyPath, privateKeyFilenamePattern),
),
nil
case pemHeaderPKCS8:
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "failed to read private key")
}
switch v := key.(type) {
case *rsa.PrivateKey:
return NewJWTHandlerRS256(v), nil
return NewJWTHandlerRS256(
v,
common.KeyIdFromPath(privateKeyPath, privateKeyFilenamePattern),
),
nil
case ed25519.PrivateKey:
return NewJWTHandlerEd25519(&v), nil
return NewJWTHandlerEd25519(
&v,
common.KeyIdFromPath(privateKeyPath, privateKeyFilenamePattern),
),
nil
}
}
return nil, errors.Errorf("unsupported server private key type")
}

func LoadKey(privateKeyPath string) (interface{}, error) {
priv, err := os.ReadFile(privateKeyPath)
block, _ := pem.Decode(priv)
if block == nil {
return nil, errors.Wrap(err, "failed to read private key")
}
switch block.Type {
case pemHeaderPKCS1:
privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "failed to read rsa private key")
}
return privKey, nil
case pemHeaderPKCS8:
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "failed to read private key")
}
return key, nil
}
return nil, errors.Errorf("unsupported server private key type")
}
38 changes: 32 additions & 6 deletions jwt/jwt_ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,64 @@ package jwt

import (
"crypto/ed25519"
"strconv"

"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"

"github.com/mendersoftware/useradm/common"
)

// JWTHandlerEd25519 is an Ed25519-specific JWTHandler
type JWTHandlerEd25519 struct {
privKey *ed25519.PrivateKey
privKey map[int]*ed25519.PrivateKey
currentKeyId int
}

func NewJWTHandlerEd25519(privKey *ed25519.PrivateKey) *JWTHandlerEd25519 {
func NewJWTHandlerEd25519(privKey *ed25519.PrivateKey, keyId int) *JWTHandlerEd25519 {
return &JWTHandlerEd25519{
privKey: privKey,
privKey: map[int]*ed25519.PrivateKey{keyId: privKey},
currentKeyId: keyId,
}
}

func (j *JWTHandlerEd25519) AddKey(privateKey interface{}, keyId int) (err error) {
if _, exists := j.privKey[keyId]; exists {
return common.ErrKeyIdCollision
}
j.privKey[keyId] = privateKey.(*ed25519.PrivateKey)
if j.currentKeyId < keyId {
j.currentKeyId = keyId
}
return nil
}

func (j *JWTHandlerEd25519) ToJWT(token *Token) (string, error) {
//generate
jt := jwt.NewWithClaims(jwt.SigningMethodEdDSA, &token.Claims)

jt.Header["kid"] = token.KeyId
if _, exists := j.privKey[token.KeyId]; !exists {
return "", common.ErrKeyIdNotFound
}
//sign
data, err := jt.SignedString(j.privKey)
data, err := jt.SignedString(j.privKey[token.KeyId])
return data, err
}

func (j *JWTHandlerEd25519) FromJWT(tokstr string) (*Token, error) {
jwttoken, err := jwt.ParseWithClaims(tokstr, &Claims{},
func(token *jwt.Token) (interface{}, error) {
keyId := common.KeyIdZero
if _, ok := token.Header["kid"]; ok {
keyId = int(token.Header["kid"].(float64)) // TODO: check types just in case
}
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, errors.New("unexpected signing method: " + token.Method.Alg())
}
return j.privKey.Public(), nil
if _, exists := j.privKey[keyId]; !exists {
return nil, errors.New("cannot find the key with id " + strconv.Itoa(keyId))
}
return j.privKey[keyId].Public(), nil
},
)

Expand Down
6 changes: 3 additions & 3 deletions jwt/jwt_ed25519_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (

func TestNewJWTHandlerEd25519(t *testing.T) {
privKey := loadEd25519PrivKey("./testdata/ed25519.pem", t)
jwtHandler := NewJWTHandlerEd25519(privKey)
jwtHandler := NewJWTHandlerEd25519(privKey, 0)

assert.NotNil(t, jwtHandler)
}
Expand Down Expand Up @@ -67,7 +67,7 @@ func TestJWTHandlerEd25519GenerateToken(t *testing.T) {

for name, tc := range testCases {
t.Logf("test case: %s", name)
jwtHandler := NewJWTHandlerEd25519(tc.privKey)
jwtHandler := NewJWTHandlerEd25519(tc.privKey, 0)

raw, err := jwtHandler.ToJWT(&Token{
Claims: tc.claims,
Expand Down Expand Up @@ -208,7 +208,7 @@ func TestJWTHandlerEd25519FromJWT(t *testing.T) {

for name, tc := range testCases {
t.Logf("test case: %s", name)
jwtHandler := NewJWTHandlerEd25519(tc.privKey)
jwtHandler := NewJWTHandlerEd25519(tc.privKey, 0)

token, err := jwtHandler.FromJWT(tc.inToken)
if tc.outErr == nil {
Expand Down
38 changes: 32 additions & 6 deletions jwt/jwt_rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,64 @@ package jwt

import (
"crypto/rsa"
"strconv"

"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"

"github.com/mendersoftware/useradm/common"
)

// JWTHandlerRS256 is an RS256-specific JWTHandler
type JWTHandlerRS256 struct {
privKey *rsa.PrivateKey
privKey map[int]*rsa.PrivateKey
currentKeyId int
}

func NewJWTHandlerRS256(privKey *rsa.PrivateKey) *JWTHandlerRS256 {
func NewJWTHandlerRS256(privKey *rsa.PrivateKey, keyId int) *JWTHandlerRS256 {
return &JWTHandlerRS256{
privKey: privKey,
privKey: map[int]*rsa.PrivateKey{keyId: privKey},
currentKeyId: keyId,
}
}

func (j *JWTHandlerRS256) AddKey(privateKey interface{}, keyId int) (err error) {
if _, exists := j.privKey[keyId]; exists {
return common.ErrKeyIdCollision
}
j.privKey[keyId] = privateKey.(*rsa.PrivateKey)
if j.currentKeyId < keyId {
j.currentKeyId = keyId
}
return nil
}

func (j *JWTHandlerRS256) ToJWT(token *Token) (string, error) {
//generate
jt := jwt.NewWithClaims(jwt.SigningMethodRS256, &token.Claims)

jt.Header["kid"] = token.KeyId
if _, exists := j.privKey[token.KeyId]; !exists {
return "", common.ErrKeyIdNotFound
}
//sign
data, err := jt.SignedString(j.privKey)
data, err := jt.SignedString(j.privKey[token.KeyId])
return data, err
}

func (j *JWTHandlerRS256) FromJWT(tokstr string) (*Token, error) {
jwttoken, err := jwt.ParseWithClaims(tokstr, &Claims{},
func(token *jwt.Token) (interface{}, error) {
keyId := common.KeyIdZero
if _, ok := token.Header["kid"]; ok {
keyId = int(token.Header["kid"].(float64)) // TODO: check types just in case
}
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, errors.New("unexpected signing method: " + token.Method.Alg())
}
return &j.privKey.PublicKey, nil
if _, exists := j.privKey[keyId]; !exists {
return nil, errors.New("cannot find the key with id " + strconv.Itoa(keyId))
}
return &j.privKey[keyId].PublicKey, nil
},
)

Expand Down
6 changes: 3 additions & 3 deletions jwt/jwt_rsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (

func TestNewJWTHandlerRS256(t *testing.T) {
privKey := loadRSAPrivKey("./testdata/rsa.pem", t)
jwtHandler := NewJWTHandlerRS256(privKey)
jwtHandler := NewJWTHandlerRS256(privKey, 0)

assert.NotNil(t, jwtHandler)
}
Expand Down Expand Up @@ -67,7 +67,7 @@ func TestJWTHandlerRS256GenerateToken(t *testing.T) {

for name, tc := range testCases {
t.Logf("test case: %s", name)
jwtHandler := NewJWTHandlerRS256(tc.privKey)
jwtHandler := NewJWTHandlerRS256(tc.privKey, 0)

raw, err := jwtHandler.ToJWT(&Token{
Claims: tc.claims,
Expand Down Expand Up @@ -223,7 +223,7 @@ func TestJWTHandlerRS256FromJWT(t *testing.T) {

for name, tc := range testCases {
t.Logf("test case: %s", name)
jwtHandler := NewJWTHandlerRS256(tc.privKey)
jwtHandler := NewJWTHandlerRS256(tc.privKey, 0)

token, err := jwtHandler.FromJWT(tc.inToken)
if tc.outErr == nil {
Expand Down
2 changes: 1 addition & 1 deletion jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestNewJWTHandler(t *testing.T) {

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
_, err := NewJWTHandler(tc.privateKeyPath)
_, err := NewJWTHandler(tc.privateKeyPath, "pem")
if tc.err != nil {
assert.EqualError(t, err, tc.err.Error())
} else {
Expand Down
14 changes: 14 additions & 0 deletions jwt/mocks/Handler.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit fcfc84f

Please sign in to comment.