diff --git a/handler.go b/handler.go index 254029e..318a227 100644 --- a/handler.go +++ b/handler.go @@ -136,11 +136,11 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Add("Vary", "Cookie") - var realToken []byte - - tokenCookie, err := r.Cookie(h.getCookieName(w, r)) - if err == nil { - realToken = b64decode(tokenCookie.Value) + var realTokens [][]byte + for _, tokenCookie := range r.Cookies() { + if tokenCookie.Name == h.getCookieName(w, r) { + realTokens = append(realTokens, b64decode(tokenCookie.Value)) + } } // If the length of the real token isn't what it should be, @@ -151,12 +151,34 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // // As a consequence, CSRF check will fail when comparing the tokens later on, // so we don't have to fail it just yet. - if len(realToken) != tokenLength { + if len(realTokens) == 0 { + // If we received no token (len==0), it means no CSRF cookie exists. We need to regenerate. + if !h.IsIgnored(r) { + h.RegenerateToken(w, r) + } + } else if len(realTokens) == 1 && len(realTokens[0]) != tokenLength { + // We received one token, but it's not the right length. + if !h.IsIgnored(r) { + h.RegenerateToken(w, r) + } + } else if len(realTokens) > 1 { + // We received multiple tokens. We need to find the correct one and set it. + sentToken := extractToken(r) + for _, realToken := range realTokens { + if verifyToken(realToken, sentToken) { + realTokens = [][]byte{realToken} + break + } + } + + // We have to regenerate because we only want one CSRF cookie. This is like + // a cleanup job! if !h.IsIgnored(r) { h.RegenerateToken(w, r) } } else { - ctxSetToken(r, realToken) + // We received one token, and it's the right length + ctxSetToken(r, realTokens[0]) } if h.IsIgnored(r) { @@ -195,7 +217,7 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Finally, we check the token itself. sentToken := extractToken(r) - if !verifyToken(realToken, sentToken) { + if !verifyToken(realTokens[0], sentToken) { ctxSetReason(r, ErrBadToken) h.handleFailure(w, r) return diff --git a/handler_test.go b/handler_test.go index 4ed3929..d719679 100644 --- a/handler_test.go +++ b/handler_test.go @@ -317,37 +317,11 @@ func TestWrongTokenFails(t *testing.T) { } } -func TestCustomCookieName(t *testing.T) { - hand := New(http.HandlerFunc(succHand)) - - if hand.getCookieName(nil, nil) != CookieName { - t.Errorf("No base cookie set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil)) - } - - hand.SetBaseCookie(http.Cookie{}) - - if hand.getCookieName(nil, nil) != CookieName { - t.Errorf("Base cookie with empty name set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil)) - } - - customCookieName := "my_custom_cookie" - hand.SetBaseCookie(http.Cookie{ - Name: customCookieName, - }) - - if hand.getCookieName(nil, nil) != customCookieName { - t.Errorf("Base cookie with name %s was set, but CookieName was %s instead", customCookieName, hand.getCookieName(nil, nil)) - } -} - // For this and similar tests we start a test server // Since it's much easier to get the cookie // from a normal http.Response than from the recorder -func TestCorrectTokenPasses(t *testing.T) { +func TestFunctionalCases(t *testing.T) { hand := New(http.HandlerFunc(succHand)) - hand.SetFailureHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Errorf("Test failed. Reason: %v", Reason(r)) - })) server := httptest.NewServer(hand) defer server.Close() @@ -371,26 +345,115 @@ func TestCorrectTokenPasses(t *testing.T) { } // Test usual POST - /* - { - req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.AddCookie(cookie) + { + req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(cookie) - resp, err = http.DefaultClient.Do(req) + resp, err = http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != 200 { - t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", - resp.StatusCode) - } + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", + resp.StatusCode) } - */ + } + + // Test multiple cookies (pass) + { + req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Add cookies with invalid values + for i := 0; i < 10; i++ { + cookieCopy := *cookie + cookieCopy.Value = b64encode(generateToken()) + req.AddCookie(&cookieCopy) + } + + // Add the real cookie. + req.AddCookie(cookie) + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != 200 { + t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", + resp.StatusCode) + } + + if len(resp.Cookies()) != 1 { + t.Errorf("The request should have have included exactly one cookie but got %+v", + resp.Cookies()) + } + } + + // Test multiple cookies (fail) + { + req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Add cookies with invalid values + for i := 0; i < 10; i++ { + cookieCopy := *cookie + cookieCopy.Value = b64encode(generateToken()) + req.AddCookie(&cookieCopy) + } + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != FailureCode { + t.Errorf("The request should have failed, but it didn't. Instead, the code was %d", + resp.StatusCode) + } + + if len(resp.Cookies()) != 1 { + t.Errorf("The request should have have included exactly one cookie but got %+v", + resp.Cookies()) + } + } + + // Test one cookie (fail) + { + req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Add cookies with invalid values + cookieCopy := *cookie + cookieCopy.Value = b64encode(generateToken()) + req.AddCookie(&cookieCopy) + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != FailureCode { + t.Errorf("The request should have failed, but it didn't. Instead, the code was %d", + resp.StatusCode) + } + + if len(resp.Cookies()) != 0 { + t.Errorf("The request should not have set a cookie, but got %+v", + resp.Cookies()) + } + } // Test multipart { @@ -432,6 +495,29 @@ func TestCorrectTokenPasses(t *testing.T) { } } +func TestCustomCookieName(t *testing.T) { + hand := New(http.HandlerFunc(succHand)) + + if hand.getCookieName(nil, nil) != CookieName { + t.Errorf("No base cookie set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil)) + } + + hand.SetBaseCookie(http.Cookie{}) + + if hand.getCookieName(nil, nil) != CookieName { + t.Errorf("Base cookie with empty name set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil)) + } + + customCookieName := "my_custom_cookie" + hand.SetBaseCookie(http.Cookie{ + Name: customCookieName, + }) + + if hand.getCookieName(nil, nil) != customCookieName { + t.Errorf("Base cookie with name %s was set, but CookieName was %s instead", customCookieName, hand.getCookieName(nil, nil)) + } +} + func TestPrefersHeaderOverFormValue(t *testing.T) { // Let's do a nice trick to find out this: // We'll set the correct token in the header