Skip to content

Commit

Permalink
fix(continuity): properly reset cookies that became invalid
Browse files Browse the repository at this point in the history
Resolves several reports related to incorrect handling of invalid continuity issues.

Closes #2121
Closes ory-corp/cloud#1786
Closes #2016
Potentially #2108
  • Loading branch information
aeneasr committed Jan 8, 2022
1 parent 871ee04 commit 12f1306
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 131 deletions.
20 changes: 12 additions & 8 deletions continuity/manager_cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
var _ Manager = new(ManagerCookie)
var ErrNotResumable = *herodot.ErrBadRequest.WithError("no resumable session found").WithReasonf("The browser does not contain the neccesary cookie to resume the session. This is a security violation and was thus blocked. Please clear your browser's cookies and cache and try again!")

const cookieName = "ory_kratos_continuity"
const CookieName = "ory_kratos_continuity"

type (
managerCookieDependencies interface {
Expand Down Expand Up @@ -47,7 +47,7 @@ func (m *ManagerCookie) Pause(ctx context.Context, w http.ResponseWriter, r *htt
}
c := NewContainer(name, *o)

if err := x.SessionPersistValues(w, r, m.d.ContinuityCookieManager(ctx), cookieName, map[string]interface{}{
if err := x.SessionPersistValues(w, r, m.d.ContinuityCookieManager(ctx), CookieName, map[string]interface{}{
name: c.ID.String(),
}); err != nil {
return err
Expand Down Expand Up @@ -85,7 +85,7 @@ func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r *
return nil, err
}

if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), cookieName, name); err != nil {
if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
return nil, err
}

Expand All @@ -94,11 +94,11 @@ func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r *

func (m *ManagerCookie) sid(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, error) {
var sid uuid.UUID
if s, err := x.SessionGetString(r, m.d.ContinuityCookieManager(ctx), cookieName, name); err != nil {
_ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), cookieName, name)
if s, err := x.SessionGetString(r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
_ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name)
return sid, errors.WithStack(ErrNotResumable.WithDebugf("%+v", err))
} else if sid = x.ParseUUID(s); sid == uuid.Nil {
_ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), cookieName, name)
_ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name)
return sid, errors.WithStack(ErrNotResumable.WithDebug("session id is not a valid uuid"))
}

Expand All @@ -112,8 +112,12 @@ func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r
}

container, err := m.d.ContinuityPersister().GetContinuitySession(ctx, sid)
// If an error happens, we need to clean up the cookie.
if err != nil {
_ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name)
}

if errors.Is(err, sqlcon.ErrNoRows) {
_ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), cookieName, name)
return nil, errors.WithStack(ErrNotResumable.WithDebugf("Resumable ID from cookie could not be found in the datastore: %+v", err))
} else if err != nil {
return nil, err
Expand All @@ -131,7 +135,7 @@ func (m ManagerCookie) Abort(ctx context.Context, w http.ResponseWriter, r *http
return err
}

if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), cookieName, name); err != nil {
if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
return err
}

Expand Down
252 changes: 139 additions & 113 deletions continuity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/ory/x/ioutilx"
Expand Down Expand Up @@ -99,121 +100,146 @@ func TestManager(t *testing.T) {
return &http.Client{Jar: x.EasyCookieJar(t, nil)}
}

for name, p := range map[string]continuity.Manager{
"cookie": reg.ContinuityManager(),
p := reg.ContinuityManager()
cl := newClient()

t.Run("case=continue cookie resets when signature is invalid", func(t *testing.T) {
ts := newServer(t, p, new(persisterTestCase))
href := ts.URL + "/" + x.NewUUID().String()

res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

req := x.NewTestHTTPRequest(t, "GET", href, nil)
require.Len(t, res.Cookies(), 1)
for _, c := range res.Cookies() {
// Change something in the string
c.Value = strings.Replace(c.Value, "a", "b", 1)
req.AddCookie(c)
}
res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Equal(t, http.StatusBadRequest, res.StatusCode)
body := ioutilx.MustReadAll(res.Body)
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)

require.Len(t, res.Cookies(), 1, "continuing the flow with a broken cookie should instruct the browser to forget it")
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)
})

for k, tc := range []persisterTestCase{
{},
{
ro: []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{"bar"})},
wo: []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{})},
expected: &persisterTestPayload{"bar"},
},
{
ro: []continuity.ManagerOption{continuity.WithIdentity(i)},
wo: []continuity.ManagerOption{continuity.WithIdentity(i)},
},
} {
t.Run(fmt.Sprintf("persister=%s", name), func(t *testing.T) {
for k, tc := range []persisterTestCase{
{},
{
ro: []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{"bar"})},
wo: []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{})},
expected: &persisterTestPayload{"bar"},
},
{
ro: []continuity.ManagerOption{continuity.WithIdentity(i)},
wo: []continuity.ManagerOption{continuity.WithIdentity(i)},
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
cl := newClient()
ts := newServer(t, p, &tc)
var genid = func() string {
return ts.URL + "/" + x.NewUUID().String()
}

t.Run("case=resume non-existing session", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

body := ioutilx.MustReadAll(res.Body)
require.Equal(t, http.StatusBadRequest, res.StatusCode)
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)
})

t.Run("case=pause and resume session", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

body := ioutilx.MustReadAll(res.Body)
if tc.expectedErr != nil {
require.Equal(t, http.StatusGone, res.StatusCode, "%s", body)
return
}

require.Equal(t, http.StatusOK, res.StatusCode, "%s", body)

var b bytes.Buffer
require.NoError(t, json.NewEncoder(&b).Encode(tc.expected))
assert.JSONEq(t, b.String(), gjson.GetBytes(body, "payload").Raw, "%s", body)
assert.Contains(t, href, gjson.GetBytes(body, "name").String(), "%s", body)
})

t.Run("case=pause and retry session", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

res, err = cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, res.StatusCode)
body := ioutilx.MustReadAll(res.Body)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)
})

t.Run("case=pause and resume session in the same request", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "POST", href, nil))
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

var b bytes.Buffer
require.NoError(t, json.NewEncoder(&b).Encode(tc.expected))

body := ioutilx.MustReadAll(res.Body)
assert.JSONEq(t, b.String(), gjson.GetBytes(body, "payload").Raw, "%s", body)
assert.Contains(t, href, gjson.GetBytes(body, "name").String(), "%s", body)
})

t.Run("case=pause, abort, and continue session with failure", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(x.NewTestHTTPRequest(t, "DELETE", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Equal(t, http.StatusBadRequest, res.StatusCode)
body := ioutilx.MustReadAll(res.Body)
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)
})
})
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
cl := newClient()
ts := newServer(t, p, &tc)
var genid = func() string {
return ts.URL + "/" + x.NewUUID().String()
}

t.Run("case=resume non-existing session", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

body := ioutilx.MustReadAll(res.Body)
require.Equal(t, http.StatusBadRequest, res.StatusCode)
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)
})

t.Run("case=pause and resume session", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

body := ioutilx.MustReadAll(res.Body)
if tc.expectedErr != nil {
require.Equal(t, http.StatusGone, res.StatusCode, "%s", body)
return
}

require.Equal(t, http.StatusOK, res.StatusCode, "%s", body)

var b bytes.Buffer
require.NoError(t, json.NewEncoder(&b).Encode(tc.expected))
assert.JSONEq(t, b.String(), gjson.GetBytes(body, "payload").Raw, "%s", body)
assert.Contains(t, href, gjson.GetBytes(body, "name").String(), "%s", body)
})

t.Run("case=pause and retry session", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

res, err = cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, res.StatusCode)
body := ioutilx.MustReadAll(res.Body)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)
})

t.Run("case=pause and resume session in the same request", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "POST", href, nil))
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

var b bytes.Buffer
require.NoError(t, json.NewEncoder(&b).Encode(tc.expected))

body := ioutilx.MustReadAll(res.Body)
assert.JSONEq(t, b.String(), gjson.GetBytes(body, "payload").Raw, "%s", body)
assert.Contains(t, href, gjson.GetBytes(body, "name").String(), "%s", body)
})

t.Run("case=pause, abort, and continue session with failure", func(t *testing.T) {
href := genid()
res, err := cl.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(x.NewTestHTTPRequest(t, "DELETE", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
require.Equal(t, http.StatusNoContent, res.StatusCode)

res, err = cl.Do(x.NewTestHTTPRequest(t, "GET", href, nil))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Equal(t, http.StatusBadRequest, res.StatusCode)
body := ioutilx.MustReadAll(res.Body)
assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField)
})
})
}
}
14 changes: 4 additions & 10 deletions x/cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,17 @@ func SessionGetStringOr(r *http.Request, s sessions.Store, id, key, fallback str
}

func SessionUnset(w http.ResponseWriter, r *http.Request, s sessions.Store, id string) error {
cookie, err := s.Get(r, id)
if err != nil {
return nil
}

cookie, _ := s.Get(r, id)
cookie.Options.MaxAge = -1
cookie.Values = make(map[interface{}]interface{})
return errors.WithStack(cookie.Save(r, w))
}

func SessionUnsetKey(w http.ResponseWriter, r *http.Request, s sessions.Store, id, key string) error {
cookie, err := s.Get(r, id)
if err != nil {
return nil
} else if cookie.IsNew {
return nil
if err == nil {
delete(cookie.Values, key)
}

delete(cookie.Values, key)
return errors.WithStack(cookie.Save(r, w))
}
Loading

0 comments on commit 12f1306

Please sign in to comment.