From c4e5496a4fec282d96dec504d279b4fab7dcf896 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Sat, 8 Jan 2022 10:24:05 +0100 Subject: [PATCH] fix: choose correct cookie when multiple are set Resolves an issue where, when multiple CSRF cookies are set, a random one would be used to verify the CSRF token. Now, regardless of how many conflicting CSRF cookies exist, if one of them is valid, the request will pass and clean up the cookie store. See https://github.com/ory/kratos/issues/2121 See https://github.com/ory-corp/cloud/issues/1786 --- handler.go | 38 ++++++++--- handler_test.go | 174 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 160 insertions(+), 52 deletions(-) diff --git a/handler.go b/handler.go index 254029e..318a227 100644 --- a/handler.go +++ b/handler.go @@ -136,11 +136,11 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Add("Vary", "Cookie") - var realToken []byte - - tokenCookie, err := r.Cookie(h.getCookieName(w, r)) - if err == nil { - realToken = b64decode(tokenCookie.Value) + var realTokens [][]byte + for _, tokenCookie := range r.Cookies() { + if tokenCookie.Name == h.getCookieName(w, r) { + realTokens = append(realTokens, b64decode(tokenCookie.Value)) + } } // If the length of the real token isn't what it should be, @@ -151,12 +151,34 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // // As a consequence, CSRF check will fail when comparing the tokens later on, // so we don't have to fail it just yet. - if len(realToken) != tokenLength { + if len(realTokens) == 0 { + // If we received no token (len==0), it means no CSRF cookie exists. We need to regenerate. + if !h.IsIgnored(r) { + h.RegenerateToken(w, r) + } + } else if len(realTokens) == 1 && len(realTokens[0]) != tokenLength { + // We received one token, but it's not the right length. + if !h.IsIgnored(r) { + h.RegenerateToken(w, r) + } + } else if len(realTokens) > 1 { + // We received multiple tokens. We need to find the correct one and set it. + sentToken := extractToken(r) + for _, realToken := range realTokens { + if verifyToken(realToken, sentToken) { + realTokens = [][]byte{realToken} + break + } + } + + // We have to regenerate because we only want one CSRF cookie. This is like + // a cleanup job! if !h.IsIgnored(r) { h.RegenerateToken(w, r) } } else { - ctxSetToken(r, realToken) + // We received one token, and it's the right length + ctxSetToken(r, realTokens[0]) } if h.IsIgnored(r) { @@ -195,7 +217,7 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Finally, we check the token itself. sentToken := extractToken(r) - if !verifyToken(realToken, sentToken) { + if !verifyToken(realTokens[0], sentToken) { ctxSetReason(r, ErrBadToken) h.handleFailure(w, r) return diff --git a/handler_test.go b/handler_test.go index 4ed3929..d719679 100644 --- a/handler_test.go +++ b/handler_test.go @@ -317,37 +317,11 @@ func TestWrongTokenFails(t *testing.T) { } } -func TestCustomCookieName(t *testing.T) { - hand := New(http.HandlerFunc(succHand)) - - if hand.getCookieName(nil, nil) != CookieName { - t.Errorf("No base cookie set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil)) - } - - hand.SetBaseCookie(http.Cookie{}) - - if hand.getCookieName(nil, nil) != CookieName { - t.Errorf("Base cookie with empty name set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil)) - } - - customCookieName := "my_custom_cookie" - hand.SetBaseCookie(http.Cookie{ - Name: customCookieName, - }) - - if hand.getCookieName(nil, nil) != customCookieName { - t.Errorf("Base cookie with name %s was set, but CookieName was %s instead", customCookieName, hand.getCookieName(nil, nil)) - } -} - // For this and similar tests we start a test server // Since it's much easier to get the cookie // from a normal http.Response than from the recorder -func TestCorrectTokenPasses(t *testing.T) { +func TestFunctionalCases(t *testing.T) { hand := New(http.HandlerFunc(succHand)) - hand.SetFailureHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Errorf("Test failed. Reason: %v", Reason(r)) - })) server := httptest.NewServer(hand) defer server.Close() @@ -371,26 +345,115 @@ func TestCorrectTokenPasses(t *testing.T) { } // Test usual POST - /* - { - req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.AddCookie(cookie) + { + req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(cookie) - resp, err = http.DefaultClient.Do(req) + resp, err = http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != 200 { - t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", - resp.StatusCode) - } + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", + resp.StatusCode) } - */ + } + + // Test multiple cookies (pass) + { + req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Add cookies with invalid values + for i := 0; i < 10; i++ { + cookieCopy := *cookie + cookieCopy.Value = b64encode(generateToken()) + req.AddCookie(&cookieCopy) + } + + // Add the real cookie. + req.AddCookie(cookie) + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != 200 { + t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", + resp.StatusCode) + } + + if len(resp.Cookies()) != 1 { + t.Errorf("The request should have have included exactly one cookie but got %+v", + resp.Cookies()) + } + } + + // Test multiple cookies (fail) + { + req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Add cookies with invalid values + for i := 0; i < 10; i++ { + cookieCopy := *cookie + cookieCopy.Value = b64encode(generateToken()) + req.AddCookie(&cookieCopy) + } + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != FailureCode { + t.Errorf("The request should have failed, but it didn't. Instead, the code was %d", + resp.StatusCode) + } + + if len(resp.Cookies()) != 1 { + t.Errorf("The request should have have included exactly one cookie but got %+v", + resp.Cookies()) + } + } + + // Test one cookie (fail) + { + req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Add cookies with invalid values + cookieCopy := *cookie + cookieCopy.Value = b64encode(generateToken()) + req.AddCookie(&cookieCopy) + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != FailureCode { + t.Errorf("The request should have failed, but it didn't. Instead, the code was %d", + resp.StatusCode) + } + + if len(resp.Cookies()) != 0 { + t.Errorf("The request should not have set a cookie, but got %+v", + resp.Cookies()) + } + } // Test multipart { @@ -432,6 +495,29 @@ func TestCorrectTokenPasses(t *testing.T) { } } +func TestCustomCookieName(t *testing.T) { + hand := New(http.HandlerFunc(succHand)) + + if hand.getCookieName(nil, nil) != CookieName { + t.Errorf("No base cookie set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil)) + } + + hand.SetBaseCookie(http.Cookie{}) + + if hand.getCookieName(nil, nil) != CookieName { + t.Errorf("Base cookie with empty name set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil)) + } + + customCookieName := "my_custom_cookie" + hand.SetBaseCookie(http.Cookie{ + Name: customCookieName, + }) + + if hand.getCookieName(nil, nil) != customCookieName { + t.Errorf("Base cookie with name %s was set, but CookieName was %s instead", customCookieName, hand.getCookieName(nil, nil)) + } +} + func TestPrefersHeaderOverFormValue(t *testing.T) { // Let's do a nice trick to find out this: // We'll set the correct token in the header