Skip to content

Commit

Permalink
fix: add ability to resume continuity sessions from several cookies (o…
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr authored Jan 10, 2022
1 parent 8c21d3a commit 2b521e0
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 34 deletions.
10 changes: 7 additions & 3 deletions continuity/manager_cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r *
}
}

if err := m.d.ContinuityPersister().DeleteContinuitySession(ctx, container.ID); err != nil {
if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
return nil, err
}

if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
if err := m.d.ContinuityPersister().DeleteContinuitySession(ctx, container.ID); err != nil && !errors.Is(err, sqlcon.ErrNoRows) {
return nil, err
}

Expand Down Expand Up @@ -139,5 +139,9 @@ func (m ManagerCookie) Abort(ctx context.Context, w http.ResponseWriter, r *http
return err
}

return errors.WithStack(m.d.ContinuityPersister().DeleteContinuitySession(ctx, sid))
if err := m.d.ContinuityPersister().DeleteContinuitySession(ctx, sid); err != nil && !errors.Is(err, sqlcon.ErrNoRows) {
return errors.WithStack(err)
}

return nil
}
43 changes: 43 additions & 0 deletions continuity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,49 @@ func TestManager(t *testing.T) {
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)
})

t.Run("case=can deal with duplicate cookies", func(t *testing.T) {
tc := &persisterTestCase{expected: &persisterTestPayload{"bar"}}
ts := newServer(t, p, tc)
href := ts.URL + "/" + x.NewUUID().String()

res, err := http.DefaultClient.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

// We change the key to another one
href = ts.URL + "/" + x.NewUUID().String()
req := x.NewTestHTTPRequest(t, "GET", href, nil)
require.Len(t, res.Cookies(), 1)
for _, c := range res.Cookies() {
req.AddCookie(c)
}

tc.ro = []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{"bar"})}
res, err = http.DefaultClient.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

require.Len(t, res.Cookies(), 1)
for _, c := range res.Cookies() {
req.AddCookie(c)
}

res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 1, "continuing the flow with a broken cookie should instruct the browser to forget it")
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)

var b bytes.Buffer
require.NoError(t, json.NewEncoder(&b).Encode(tc.expected))
body := ioutilx.MustReadAll(res.Body)
assert.JSONEq(t, b.String(), gjson.GetBytes(body, "payload").Raw, "%s", body)
assert.Contains(t, href, gjson.GetBytes(body, "name").String(), "%s", body)
})

for k, tc := range []persisterTestCase{
{},
{
Expand Down
6 changes: 3 additions & 3 deletions driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ type Registry interface {
WithCSRFHandler(c nosurf.Handler)
WithCSRFTokenGenerator(cg x.CSRFToken)

HealthHandler(ctx context.Context) *healthx.Handler
CookieManager(ctx context.Context) sessions.Store
MetricsHandler() *prometheus.Handler
ContinuityCookieManager(ctx context.Context) sessions.Store
HealthHandler(ctx context.Context) *healthx.Handler
CookieManager(ctx context.Context) sessions.StoreExact
ContinuityCookieManager(ctx context.Context) sessions.StoreExact

RegisterRoutes(ctx context.Context, public *x.RouterPublic, admin *x.RouterAdmin)
RegisterPublicRoutes(ctx context.Context, public *x.RouterPublic)
Expand Down
4 changes: 2 additions & 2 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func (m *RegistryDefault) SelfServiceErrorHandler() *errorx.Handler {
return m.errorHandler
}

func (m *RegistryDefault) CookieManager(ctx context.Context) sessions.Store {
func (m *RegistryDefault) CookieManager(ctx context.Context) sessions.StoreExact {
cs := sessions.NewCookieStore(m.Config(ctx).SecretsSession()...)
cs.Options.Secure = !m.Config(ctx).IsInsecureDevMode()
cs.Options.HttpOnly = true
Expand All @@ -447,7 +447,7 @@ func (m *RegistryDefault) CookieManager(ctx context.Context) sessions.Store {
return cs
}

func (m *RegistryDefault) ContinuityCookieManager(ctx context.Context) sessions.Store {
func (m *RegistryDefault) ContinuityCookieManager(ctx context.Context) sessions.StoreExact {
// To support hot reloading, this can not be instantiated only once.
cs := sessions.NewCookieStore(m.Config(ctx).SecretsSession()...)
cs.Options.Secure = !m.Config(ctx).IsInsecureDevMode()
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.16

replace (
github.com/bradleyjkemp/cupaloy/v2 => github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3
github.com/gorilla/sessions => github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2
github.com/jackc/pgconn => github.com/jackc/pgconn v1.10.1-0.20211002123621-290ee79d1e8d
github.com/knadh/koanf => github.com/aeneasr/koanf v0.14.1-0.20211230115640-aa3902b3267a
github.com/luna-duclos/instrumentedsql => github.com/ory/instrumentedsql v1.2.0
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1058,10 +1058,6 @@ github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2z
github.com/gorilla/pat v0.0.0-20180118222023-199c85a7f6d1/go.mod h1:YeAe0gNeiNT5hoiZRI4yiOky6jVdNvfO2N6Kav/HmxY=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.1.2/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
Expand Down Expand Up @@ -1583,6 +1579,8 @@ github.com/ory/mail/v3 v3.0.0 h1:8LFMRj473vGahFD/ntiotWEd4S80FKYFtiZTDfOQ+sM=
github.com/ory/mail/v3 v3.0.0/go.mod h1:JGAVeZF8YAlxbaFDUHqRZAKBCSeW2w1vuxf28hFbZAw=
github.com/ory/nosurf v1.2.7 h1:YrHrbSensQyU6r6HT/V5+HPdVEgrOTMJiLoJABSBOp4=
github.com/ory/nosurf v1.2.7/go.mod h1:d4L3ZBa7Amv55bqxCBtCs63wSlyaiCkWVl4vKf3OUxA=
github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2 h1:zm6sDvHy/U9XrGpixwHiuAwpp0Ock6khSVHkrv6lQQU=
github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/ory/viper v1.5.6/go.mod h1:TYmpFpKLxjQwvT4f0QPpkOn4sDXU1kDgAwJpgLYiQ28=
github.com/ory/viper v1.7.4/go.mod h1:T6sodNZKNGPpashUOk7EtXz2isovz8oCd57GNVkkNmE=
github.com/ory/viper v1.7.5 h1:+xVdq7SU3e1vNaCsk/ixsfxE4zylk1TJUiJrY647jUE=
Expand Down
45 changes: 26 additions & 19 deletions x/cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@ import (
)

// SessionPersistValues adds values to the session store and persists the changes.
func SessionPersistValues(w http.ResponseWriter, r *http.Request, s sessions.Store, id string, values map[string]interface{}) error {
func SessionPersistValues(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, id string, values map[string]interface{}) error {
// The error does not matter because in the worst case we're re-writing the session cookie.
cookie, err := s.Get(r, id)
if err != nil {
cookie = sessions.NewSession(s, id)
}

cookie, _ := s.Get(r, id)
for k, v := range values {
cookie.Values[k] = v
}
Expand All @@ -24,32 +20,43 @@ func SessionPersistValues(w http.ResponseWriter, r *http.Request, s sessions.Sto

// SessionGetString returns a string for the given id and key or an error if the session is invalid,
// the key does not exist, or the key value is not a string.
func SessionGetString(r *http.Request, s sessions.Store, id string, key interface{}) (string, error) {
cookie, err := s.Get(r, id)
if err != nil {
return "", errors.WithStack(err)
func SessionGetString(r *http.Request, s sessions.StoreExact, id string, key interface{}) (string, error) {
check := func(v map[interface{}]interface{}) (string, error) {
vv, ok := v[key]
if !ok {
return "", errors.Errorf("key %s does not exist in cookie: %+v", key, id)
} else if vvv, ok := vv.(string); !ok {
return "", errors.Errorf("value of key %s is not of type string in cookie", key)
} else {
return vvv, nil
}
}

if v, ok := cookie.Values[key]; !ok {
return "", errors.Errorf("key %s does not exist in cookie: %+v", key, cookie.Values)
} else if vv, ok := v.(string); !ok {
return "", errors.Errorf("value of key %s is not of type string in cookie", key)
} else {
return vv, nil
var exactErr error
cookie, err := s.GetExact(r, id, func(s *sessions.Session) bool {
_, exactErr = check(s.Values)
return exactErr == nil
})
if err != nil {
return "", err
} else if exactErr != nil {
return "", exactErr
}

return check(cookie.Values)
}

// SessionGetStringOr returns a string for the given id and key or the fallback value if the session is invalid,
// the key does not exist, or the key value is not a string.
func SessionGetStringOr(r *http.Request, s sessions.Store, id, key, fallback string) string {
func SessionGetStringOr(r *http.Request, s sessions.StoreExact, id, key, fallback string) string {
v, err := SessionGetString(r, s, id, key)
if err != nil {
return fallback
}
return v
}

func SessionUnset(w http.ResponseWriter, r *http.Request, s sessions.Store, id string) error {
func SessionUnset(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, id string) error {
cookie, err := s.Get(r, id)
if err == nil && cookie.IsNew {
// No cookie was sent in the request. We have nothing to do.
Expand All @@ -61,7 +68,7 @@ func SessionUnset(w http.ResponseWriter, r *http.Request, s sessions.Store, id s
return errors.WithStack(cookie.Save(r, w))
}

func SessionUnsetKey(w http.ResponseWriter, r *http.Request, s sessions.Store, id, key string) error {
func SessionUnsetKey(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, id, key string) error {
cookie, err := s.Get(r, id)
if err == nil && cookie.IsNew {
// No cookie was sent in the request. We have nothing to do.
Expand Down
43 changes: 42 additions & 1 deletion x/cookie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,48 @@ func TestSession(t *testing.T) {

w.WriteHeader(http.StatusNoContent)
})
mr(t, id)
})

t.Run("case=GetStringMultipleCookies", func(t *testing.T) {
id := "get-string-multiple"

router.GET("/set/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
require.NoError(t, SessionPersistValues(w, r, s, sid, map[string]interface{}{
"multiple-string-1": "foo",
}))
require.NoError(t, SessionPersistValues(w, r, s, sid, map[string]interface{}{
"multiple-string-2": "bar",
}))
isExpiryCorrect(t, r)
w.WriteHeader(http.StatusNoContent)
})

router.GET("/get/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
got, err := SessionGetString(r, s, sid, "multiple-string-1")
require.NoError(t, err)
assert.EqualValues(t, "foo", got)

got, err = SessionGetString(r, s, sid, "multiple-string-2")
require.NoError(t, err)
assert.EqualValues(t, "bar", got)

w.WriteHeader(http.StatusNoContent)
})

res, err := http.DefaultClient.Get(ts.URL + "/set/" + id)
require.NoError(t, err)
require.EqualValues(t, http.StatusNoContent, res.StatusCode)
require.NoError(t, res.Body.Close())

req, _ := http.NewRequest("GET", ts.URL+"/get/"+id, nil)
for _, c := range res.Cookies() {
req.AddCookie(c)
}

res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.EqualValues(t, http.StatusNoContent, res.StatusCode)
require.NoError(t, res.Body.Close())
})

t.Run("case=GetStringOr", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions x/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ type WriterProvider interface {
}

type CookieProvider interface {
CookieManager(ctx context.Context) sessions.Store
ContinuityCookieManager(ctx context.Context) sessions.Store
CookieManager(ctx context.Context) sessions.StoreExact
ContinuityCookieManager(ctx context.Context) sessions.StoreExact
}

type TracingProvider interface {
Expand Down

0 comments on commit 2b521e0

Please sign in to comment.