Skip to content

Commit

Permalink
fix: choose correct cookie when multiple are set
Browse files Browse the repository at this point in the history
Resolves an issue where, when multiple CSRF cookies are set, a random one would be used to verify the CSRF token. Now, regardless of how many conflicting CSRF cookies exist, if one of them is valid, the request will pass and clean up the cookie store.

See ory/kratos#2121
See ory-corp/cloud#1786
  • Loading branch information
aeneasr committed Jan 8, 2022
1 parent 2dc2119 commit d28b877
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 52 deletions.
38 changes: 30 additions & 8 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

w.Header().Add("Vary", "Cookie")

var realToken []byte

tokenCookie, err := r.Cookie(h.getCookieName(w, r))
if err == nil {
realToken = b64decode(tokenCookie.Value)
var realTokens [][]byte
for _, tokenCookie := range r.Cookies() {
if tokenCookie.Name == h.getCookieName(w, r) {
realTokens = append(realTokens, b64decode(tokenCookie.Value))
}
}

// If the length of the real token isn't what it should be,
Expand All @@ -151,12 +151,34 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
//
// As a consequence, CSRF check will fail when comparing the tokens later on,
// so we don't have to fail it just yet.
if len(realToken) != tokenLength {
if len(realTokens) == 0 {
// If we received no token (len==0), it means no CSRF cookie exists. We need to regenerate.
if !h.IsIgnored(r) {
h.RegenerateToken(w, r)
}
} else if len(realTokens) == 1 && len(realTokens[0]) != tokenLength {
// We received one token, but it's not the right length.
if !h.IsIgnored(r) {
h.RegenerateToken(w, r)
}
} else if len(realTokens) > 1 {
// We received multiple tokens. We need to find the correct one and set it.
sentToken := extractToken(r)
for _, realToken := range realTokens {
if verifyToken(realToken, sentToken) {
realTokens = [][]byte{realToken}
break
}
}

// We have to regenerate because we only want one CSRF cookie. This is like
// a cleanup job!
if !h.IsIgnored(r) {
h.RegenerateToken(w, r)
}
} else {
ctxSetToken(r, realToken)
// We received one token, and it's the right length
ctxSetToken(r, realTokens[0])
}

if h.IsIgnored(r) {
Expand Down Expand Up @@ -195,7 +217,7 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Finally, we check the token itself.
sentToken := extractToken(r)

if !verifyToken(realToken, sentToken) {
if !verifyToken(realTokens[0], sentToken) {
ctxSetReason(r, ErrBadToken)
h.handleFailure(w, r)
return
Expand Down
174 changes: 130 additions & 44 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,37 +317,11 @@ func TestWrongTokenFails(t *testing.T) {
}
}

func TestCustomCookieName(t *testing.T) {
hand := New(http.HandlerFunc(succHand))

if hand.getCookieName(nil, nil) != CookieName {
t.Errorf("No base cookie set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil))
}

hand.SetBaseCookie(http.Cookie{})

if hand.getCookieName(nil, nil) != CookieName {
t.Errorf("Base cookie with empty name set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil))
}

customCookieName := "my_custom_cookie"
hand.SetBaseCookie(http.Cookie{
Name: customCookieName,
})

if hand.getCookieName(nil, nil) != customCookieName {
t.Errorf("Base cookie with name %s was set, but CookieName was %s instead", customCookieName, hand.getCookieName(nil, nil))
}
}

// For this and similar tests we start a test server
// Since it's much easier to get the cookie
// from a normal http.Response than from the recorder
func TestCorrectTokenPasses(t *testing.T) {
func TestFunctionalCases(t *testing.T) {
hand := New(http.HandlerFunc(succHand))
hand.SetFailureHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Errorf("Test failed. Reason: %v", Reason(r))
}))

server := httptest.NewServer(hand)
defer server.Close()
Expand All @@ -371,26 +345,115 @@ func TestCorrectTokenPasses(t *testing.T) {
}

// Test usual POST
/*
{
req, err := http.NewRequest("POST", server.URL, formBodyR(vals))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(cookie)
{
req, err := http.NewRequest("POST", server.URL, formBodyR(vals))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(cookie)

resp, err = http.DefaultClient.Do(req)
resp, err = http.DefaultClient.Do(req)

if err != nil {
t.Fatal(err)
}
if resp.StatusCode != 200 {
t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d",
resp.StatusCode)
}
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != 200 {
t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d",
resp.StatusCode)
}
*/
}

// Test multiple cookies (pass)
{
req, err := http.NewRequest("POST", server.URL, formBodyR(vals))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Add cookies with invalid values
for i := 0; i < 10; i++ {
cookieCopy := *cookie
cookieCopy.Value = b64encode(generateToken())
req.AddCookie(&cookieCopy)
}

// Add the real cookie.
req.AddCookie(cookie)

resp, err = http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}

if resp.StatusCode != 200 {
t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d",
resp.StatusCode)
}

if len(resp.Cookies()) != 1 {
t.Errorf("The request should have have included exactly one cookie but got %+v",
resp.Cookies())
}
}

// Test multiple cookies (fail)
{
req, err := http.NewRequest("POST", server.URL, formBodyR(vals))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Add cookies with invalid values
for i := 0; i < 10; i++ {
cookieCopy := *cookie
cookieCopy.Value = b64encode(generateToken())
req.AddCookie(&cookieCopy)
}

resp, err = http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}

if resp.StatusCode != FailureCode {
t.Errorf("The request should have failed, but it didn't. Instead, the code was %d",
resp.StatusCode)
}

if len(resp.Cookies()) != 1 {
t.Errorf("The request should have have included exactly one cookie but got %+v",
resp.Cookies())
}
}

// Test one cookie (fail)
{
req, err := http.NewRequest("POST", server.URL, formBodyR(vals))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Add cookies with invalid values
cookieCopy := *cookie
cookieCopy.Value = b64encode(generateToken())
req.AddCookie(&cookieCopy)

resp, err = http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}

if resp.StatusCode != FailureCode {
t.Errorf("The request should have failed, but it didn't. Instead, the code was %d",
resp.StatusCode)
}

if len(resp.Cookies()) != 0 {
t.Errorf("The request should not have set a cookie, but got %+v",
resp.Cookies())
}
}

// Test multipart
{
Expand Down Expand Up @@ -432,6 +495,29 @@ func TestCorrectTokenPasses(t *testing.T) {
}
}

func TestCustomCookieName(t *testing.T) {
hand := New(http.HandlerFunc(succHand))

if hand.getCookieName(nil, nil) != CookieName {
t.Errorf("No base cookie set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil))
}

hand.SetBaseCookie(http.Cookie{})

if hand.getCookieName(nil, nil) != CookieName {
t.Errorf("Base cookie with empty name set, expected CookieName to be %s, was %s", CookieName, hand.getCookieName(nil, nil))
}

customCookieName := "my_custom_cookie"
hand.SetBaseCookie(http.Cookie{
Name: customCookieName,
})

if hand.getCookieName(nil, nil) != customCookieName {
t.Errorf("Base cookie with name %s was set, but CookieName was %s instead", customCookieName, hand.getCookieName(nil, nil))
}
}

func TestPrefersHeaderOverFormValue(t *testing.T) {
// Let's do a nice trick to find out this:
// We'll set the correct token in the header
Expand Down

0 comments on commit d28b877

Please sign in to comment.