diff --git a/README.md b/README.md index 0cd1ac4..cb67fd2 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ Import as `import "github.com/kataras/jwt"` and use it as `jwt.XXX`. * [Benchmarks](_benchmarks) * [Examples](_examples) * [Basic](_examples/basic/main.go) + * [Custom Header](_examples/custom-header/main.go) * [HTTP Middleware](_examples/middleware/main.go) * [Blocklist](_examples/blocklist/main.go) * [JSON Required Tag](_examples/required/main.go) @@ -139,6 +140,8 @@ claims := map[string]interface{}{ token, err := jwt.Sign(jwt.HS256, sharedKey, claims) ``` +> See `SignWithHeader` too. + Example Code to merge map claims with standard claims: ```go @@ -239,6 +242,8 @@ Verifying a Token is done through the `Verify` package-level function. verifiedToken, err := jwt.Verify(jwt.HS256, sharedKey, token) ``` +> See `VerifyWithHeaderValidator` too. + The `VerifiedToken` carries the token decoded information: ```go diff --git a/_examples/custom-header/main.go b/_examples/custom-header/main.go new file mode 100644 index 0000000..6087c5b --- /dev/null +++ b/_examples/custom-header/main.go @@ -0,0 +1,95 @@ +package main + +import ( + "fmt" + "log" + "time" + + "github.com/kataras/jwt" +) + +// Claims is an example of custom claims. +type Claims struct { + Email string `json:"email"` +} + +func main() { + privateKey, err := jwt.LoadPrivateKeyRSA("../../_testfiles/rsa_private_key.pem") + if err != nil { + log.Fatal(err) + } + + // Generate a token with custom claims and custom jwt header. + claims := Claims{Email: "kataras2006@hotmail.com"} + header := Header{ + Kid: "my_key_id_1", + Alg: jwt.RS256.Name(), + } + token, err := jwt.SignWithHeader(jwt.RS256, privateKey, claims, header, jwt.MaxAge(10*time.Minute)) + if err != nil { + log.Fatal(err) + } + log.Printf("Generated token: %s", token) + + // Verify the token with a custom header validator and public key per-token. + verifiedToken, err := jwt.VerifyWithHeaderValidator(jwt.RS256, nil, token, validateHeader) + if err != nil { + log.Fatal(err) + } + + var getClaims Claims + err = verifiedToken.Claims(&getClaims) + if err != nil { + log.Fatal(err) + } + + log.Printf("Verified claims: %#+v", getClaims) +} + +var keys = map[string][]byte{ + "my_key_id_1": []byte(`-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAw6OJ4K9LUz6MugrF7uB+ +/oZw8/f3J4CSPYZFXMTsWNVQSLlen6/pr7ZvyPsgLvBGikybxRu7ff6ufmHTWTm7 +mlpxEv/bgFFUmfH/faY7SA1PJcWMaEMT6s7E96orefyTMNdLi4OKhUGYJ56L8cE1 +yRIya+B2UMCg2ItK11TRQlHLwvKRGsFFirc23oHX8gMuduEkIb5dSD6rEaopR3ZM +O1tipfNrlCZs5kTaIubFRJ6K1xy2Rk2hVhqdaX6Ud2aWwrb7o21REkDbqY9YuOGV +/FnDiqDtIoS7MHl5CAguaL9YiOv3RRvCrUttfuHqbljlD7m6/69rMB1cVfbdr5IB +RQIDAQAB +-----END PUBLIC KEY----- +`), + // ...more keys +} + +// Header is an example of custom header. +type Header struct { + Kid string `json:"kid"` + Alg string `json:"alg"` +} + +func validateHeader(alg string, headerDecoded []byte) (jwt.PublicKey, error) { + var h Header + err := jwt.Unmarshal(headerDecoded, &h) + if err != nil { + return nil, err + } + + if h.Alg != alg { + return nil, jwt.ErrTokenAlg + } + + if h.Kid == "" { + return nil, fmt.Errorf("kid is empty") + } + + key, ok := keys[h.Kid] + if !ok { + return nil, fmt.Errorf("unknown kid") + } + + publicKey, err := jwt.ParsePublicKeyRSA(key) + if err != nil { + return nil, jwt.ErrTokenAlg + } + + return publicKey, nil +} diff --git a/token.go b/token.go index 794989b..9631f56 100644 --- a/token.go +++ b/token.go @@ -74,14 +74,18 @@ func decodeToken(alg Alg, key PublicKey, token []byte, compareHeaderFunc HeaderV } // validate header equality. - if compareHeaderFunc != nil { - if err := compareHeaderFunc(alg.Name(), headerDecoded); err != nil { - return nil, nil, nil, err - } - } else { - if err := CompareHeader(alg.Name(), headerDecoded); err != nil { - return nil, nil, nil, err - } + if compareHeaderFunc == nil { + compareHeaderFunc = CompareHeader + } + + pubKey, err := compareHeaderFunc(alg.Name(), headerDecoded) + if err != nil { + return nil, nil, nil, err + } + + // Override the key given, which could be a nil if this "pubKey" always expected on success. + if pubKey != nil { + key = pubKey } signatureDecoded, err := Base64Decode(signature) @@ -183,14 +187,16 @@ func createHeaderReversed(alg string) []byte { // HeaderValidator is a function which can be used to customize how the header is validated, // by default it makes sure the algorithm is the same as the "alg" field. -type HeaderValidator func(alg string, headerDecoded []byte) error +// It should return a nil PublicKey and a non-nil error on validation failure. +// On success, if public key is not nil then it overrides the VerifyXXX method's one. +type HeaderValidator func(alg string, headerDecoded []byte) (PublicKey, error) // Note that this check is fully hard coded for known // algorithms and it is fully hard coded in terms of // its serialized format. -func compareHeader(alg string, headerDecoded []byte) error { +func compareHeader(alg string, headerDecoded []byte) (PublicKey, error) { if len(headerDecoded) < 25 /* 28 but allow custom short algs*/ { - return ErrTokenAlg + return nil, ErrTokenAlg } // Fast check if the order is reversed. @@ -200,18 +206,18 @@ func compareHeader(alg string, headerDecoded []byte) error { if headerDecoded[2] == 't' { expectedHeader := createHeaderReversed(alg) if !bytes.Equal(expectedHeader, headerDecoded) { - return ErrTokenAlg + return nil, ErrTokenAlg } - return nil + return nil, nil } expectedHeader := createHeaderRaw(alg) if !bytes.Equal(expectedHeader, headerDecoded) { - return ErrTokenAlg + return nil, ErrTokenAlg } - return nil + return nil, nil } func createSignature(alg Alg, key PrivateKey, headerAndPayload []byte) ([]byte, error) { diff --git a/token_test.go b/token_test.go index 85d0538..e26e417 100644 --- a/token_test.go +++ b/token_test.go @@ -93,7 +93,7 @@ func TestCompareHeader(t *testing.T) { } for i, tt := range tests { - err := compareHeader(tt.alg, []byte(tt.header)) + _, err := compareHeader(tt.alg, []byte(tt.header)) if tt.ok && err != nil { t.Fatalf("[%d] expected to pass but got error: %v", i, err) }