diff --git a/HISTORY.md b/HISTORY.md index 8738a8aaf..4b226f6e4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,12 @@ This is a list of breaking changes. As long as `1.0.0` is not released, breaking changes will be addressed as minor version bumps (`0.1.0` -> `0.2.0`). +## 0.6.0 + +A bug related to refresh tokens was found. To mitigate it, a `Clone()` method has been introduced to the `fosite.Session` interface. +If you use a custom session object, this will be a breaking change. Fosite's default sessions have been upgraded and no additional +work should be required. If you use your own session struct, we encourage using package `gob/encoding` to deep-copy it in `Clone()`. + ## 0.5.0 Breaking changes: diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index d14376319..f952c8706 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -57,7 +57,7 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex return errors.Wrap(fosite.ErrInvalidRequest, "Client ID mismatch") } - request.SetSession(originalRequest.GetSession()) + request.SetSession(originalRequest.GetSession().Clone()) request.SetRequestedScopes(originalRequest.GetRequestedScopes()) for _, scope := range originalRequest.GetGrantedScopes() { request.GrantScope(scope) diff --git a/handler/oauth2/flow_refresh_test.go b/handler/oauth2/flow_refresh_test.go index 5378f2722..2e03967b2 100644 --- a/handler/oauth2/flow_refresh_test.go +++ b/handler/oauth2/flow_refresh_test.go @@ -90,7 +90,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { }, nil) }, expect: func() { - assert.Equal(t, sess, areq.Session) + assert.NotEqual(t, sess, areq.Session) assert.NotEqual(t, time.Now().Add(-time.Hour).Round(time.Hour), areq.RequestedAt) assert.Equal(t, fosite.Arguments{"foo", "offline"}, areq.GrantedScopes) assert.Equal(t, fosite.Arguments{"foo", "bar"}, areq.Scopes) diff --git a/handler/oauth2/strategy_jwt_session.go b/handler/oauth2/strategy_jwt_session.go index ce485c268..75b1798c9 100644 --- a/handler/oauth2/strategy_jwt_session.go +++ b/handler/oauth2/strategy_jwt_session.go @@ -4,6 +4,8 @@ import ( "github.com/ory-am/fosite" "github.com/ory-am/fosite/token/jwt" "time" + "bytes" + "encoding/gob" ) type JWTSessionContainer interface { @@ -71,3 +73,13 @@ func (s *JWTSession) GetSubject() string { return s.Subject } + +func (s *JWTSession) Clone() fosite.Session { + var clone JWTSession + var mod bytes.Buffer + enc := gob.NewEncoder(&mod) + dec := gob.NewDecoder(&mod) + _ = enc.Encode(s) + _ = dec.Decode(&clone) + return &clone +} diff --git a/handler/openid/strategy_jwt.go b/handler/openid/strategy_jwt.go index a2ef76564..ecb26e706 100644 --- a/handler/openid/strategy_jwt.go +++ b/handler/openid/strategy_jwt.go @@ -2,13 +2,14 @@ package openid import ( "net/http" - + "encoding/gob" "time" "github.com/ory-am/fosite" "github.com/ory-am/fosite/token/jwt" "github.com/pkg/errors" "golang.org/x/net/context" + "bytes" ) const defaultExpiryTime = time.Hour @@ -36,6 +37,17 @@ func NewDefaultSession() *DefaultSession { } } +func (s *DefaultSession) Clone() fosite.Session { + var clone DefaultSession + var mod bytes.Buffer + enc := gob.NewEncoder(&mod) + dec := gob.NewDecoder(&mod) + _ = enc.Encode(s) + _ = dec.Decode(&clone) + return &clone +} + + func (s *DefaultSession) SetExpiresAt(key fosite.TokenType, exp time.Time) { if s.ExpiresAt == nil { s.ExpiresAt = make(map[fosite.TokenType]time.Time) diff --git a/session.go b/session.go index 2bf313fd4..f849bbb83 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,10 @@ package fosite -import "time" +import ( + "time" + "bytes" + "encoding/gob" +) // Session is an interface that is used to store session data between OAuth2 requests. It can be used to look up // when a session expires or what the subject's name was. @@ -20,6 +24,9 @@ type Session interface { // GetSubject returns the subject, if set. This is optional and only used during token introspection. GetSubject() string + + // Clone clones the session. + Clone() Session } // DefaultSession is a default implementation of the session interface. @@ -61,3 +68,13 @@ func (s *DefaultSession) GetSubject() string { return s.Subject } + +func (s *DefaultSession) Clone() Session { + var clone DefaultSession + var mod bytes.Buffer + enc := gob.NewEncoder(&mod) + dec := gob.NewDecoder(&mod) + _ = enc.Encode(s) + _ = dec.Decode(&clone) + return &clone +}