Skip to content

Commit

Permalink
Add support for public key per-token and add an example
Browse files Browse the repository at this point in the history
Relative to: #2
  • Loading branch information
kataras committed Feb 11, 2021
1 parent 5e6e0e3 commit eb8757e
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 16 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Import as `import "github.com/kataras/jwt"` and use it as `jwt.XXX`.
* [Benchmarks](_benchmarks)
* [Examples](_examples)
* [Basic](_examples/basic/main.go)
* [Custom Header](_examples/custom-header/main.go)
* [HTTP Middleware](_examples/middleware/main.go)
* [Blocklist](_examples/blocklist/main.go)
* [JSON Required Tag](_examples/required/main.go)
Expand Down Expand Up @@ -139,6 +140,8 @@ claims := map[string]interface{}{
token, err := jwt.Sign(jwt.HS256, sharedKey, claims)
```

> See `SignWithHeader` too.
Example Code to merge map claims with standard claims:

```go
Expand Down Expand Up @@ -239,6 +242,8 @@ Verifying a Token is done through the `Verify` package-level function.
verifiedToken, err := jwt.Verify(jwt.HS256, sharedKey, token)
```

> See `VerifyWithHeaderValidator` too.
The `VerifiedToken` carries the token decoded information:

```go
Expand Down
95 changes: 95 additions & 0 deletions _examples/custom-header/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package main

import (
"fmt"
"log"
"time"

"github.com/kataras/jwt"
)

// Claims is an example of custom claims.
type Claims struct {
Email string `json:"email"`
}

func main() {
privateKey, err := jwt.LoadPrivateKeyRSA("../../_testfiles/rsa_private_key.pem")
if err != nil {
log.Fatal(err)
}

// Generate a token with custom claims and custom jwt header.
claims := Claims{Email: "[email protected]"}
header := Header{
Kid: "my_key_id_1",
Alg: jwt.RS256.Name(),
}
token, err := jwt.SignWithHeader(jwt.RS256, privateKey, claims, header, jwt.MaxAge(10*time.Minute))
if err != nil {
log.Fatal(err)
}
log.Printf("Generated token: %s", token)

// Verify the token with a custom header validator and public key per-token.
verifiedToken, err := jwt.VerifyWithHeaderValidator(jwt.RS256, nil, token, validateHeader)
if err != nil {
log.Fatal(err)
}

var getClaims Claims
err = verifiedToken.Claims(&getClaims)
if err != nil {
log.Fatal(err)
}

log.Printf("Verified claims: %#+v", getClaims)
}

var keys = map[string][]byte{
"my_key_id_1": []byte(`-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAw6OJ4K9LUz6MugrF7uB+
/oZw8/f3J4CSPYZFXMTsWNVQSLlen6/pr7ZvyPsgLvBGikybxRu7ff6ufmHTWTm7
mlpxEv/bgFFUmfH/faY7SA1PJcWMaEMT6s7E96orefyTMNdLi4OKhUGYJ56L8cE1
yRIya+B2UMCg2ItK11TRQlHLwvKRGsFFirc23oHX8gMuduEkIb5dSD6rEaopR3ZM
O1tipfNrlCZs5kTaIubFRJ6K1xy2Rk2hVhqdaX6Ud2aWwrb7o21REkDbqY9YuOGV
/FnDiqDtIoS7MHl5CAguaL9YiOv3RRvCrUttfuHqbljlD7m6/69rMB1cVfbdr5IB
RQIDAQAB
-----END PUBLIC KEY-----
`),
// ...more keys
}

// Header is an example of custom header.
type Header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
}

func validateHeader(alg string, headerDecoded []byte) (jwt.PublicKey, error) {
var h Header
err := jwt.Unmarshal(headerDecoded, &h)
if err != nil {
return nil, err
}

if h.Alg != alg {
return nil, jwt.ErrTokenAlg
}

if h.Kid == "" {
return nil, fmt.Errorf("kid is empty")
}

key, ok := keys[h.Kid]
if !ok {
return nil, fmt.Errorf("unknown kid")
}

publicKey, err := jwt.ParsePublicKeyRSA(key)
if err != nil {
return nil, jwt.ErrTokenAlg
}

return publicKey, nil
}
36 changes: 21 additions & 15 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,18 @@ func decodeToken(alg Alg, key PublicKey, token []byte, compareHeaderFunc HeaderV
}

// validate header equality.
if compareHeaderFunc != nil {
if err := compareHeaderFunc(alg.Name(), headerDecoded); err != nil {
return nil, nil, nil, err
}
} else {
if err := CompareHeader(alg.Name(), headerDecoded); err != nil {
return nil, nil, nil, err
}
if compareHeaderFunc == nil {
compareHeaderFunc = CompareHeader
}

pubKey, err := compareHeaderFunc(alg.Name(), headerDecoded)
if err != nil {
return nil, nil, nil, err
}

// Override the key given, which could be a nil if this "pubKey" always expected on success.
if pubKey != nil {
key = pubKey
}

signatureDecoded, err := Base64Decode(signature)
Expand Down Expand Up @@ -183,14 +187,16 @@ func createHeaderReversed(alg string) []byte {

// HeaderValidator is a function which can be used to customize how the header is validated,
// by default it makes sure the algorithm is the same as the "alg" field.
type HeaderValidator func(alg string, headerDecoded []byte) error
// It should return a nil PublicKey and a non-nil error on validation failure.
// On success, if public key is not nil then it overrides the VerifyXXX method's one.
type HeaderValidator func(alg string, headerDecoded []byte) (PublicKey, error)

// Note that this check is fully hard coded for known
// algorithms and it is fully hard coded in terms of
// its serialized format.
func compareHeader(alg string, headerDecoded []byte) error {
func compareHeader(alg string, headerDecoded []byte) (PublicKey, error) {
if len(headerDecoded) < 25 /* 28 but allow custom short algs*/ {
return ErrTokenAlg
return nil, ErrTokenAlg
}

// Fast check if the order is reversed.
Expand All @@ -200,18 +206,18 @@ func compareHeader(alg string, headerDecoded []byte) error {
if headerDecoded[2] == 't' {
expectedHeader := createHeaderReversed(alg)
if !bytes.Equal(expectedHeader, headerDecoded) {
return ErrTokenAlg
return nil, ErrTokenAlg
}

return nil
return nil, nil
}

expectedHeader := createHeaderRaw(alg)
if !bytes.Equal(expectedHeader, headerDecoded) {
return ErrTokenAlg
return nil, ErrTokenAlg
}

return nil
return nil, nil
}

func createSignature(alg Alg, key PrivateKey, headerAndPayload []byte) ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestCompareHeader(t *testing.T) {
}

for i, tt := range tests {
err := compareHeader(tt.alg, []byte(tt.header))
_, err := compareHeader(tt.alg, []byte(tt.header))
if tt.ok && err != nil {
t.Fatalf("[%d] expected to pass but got error: %v", i, err)
}
Expand Down

0 comments on commit eb8757e

Please sign in to comment.