Skip to content

Commit

Permalink
fix: always rollback
Browse files Browse the repository at this point in the history
Fixes #637.
  • Loading branch information
mitar committed Jan 14, 2022
1 parent cf2c545 commit 47584f2
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 39 deletions.
26 changes: 12 additions & 14 deletions handler/oauth2/flow_authorize_code_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func canIssueRefreshToken(c *AuthorizeExplicitGrantHandler, request fosite.Reque
return true
}

func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error {
func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) (err error) {
if !c.CanHandleTokenEndpointRequest(requester) {
return errorsx.WithStack(fosite.ErrUnknownRequest)
}
Expand Down Expand Up @@ -170,22 +170,20 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

if err := c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr))
defer func() {
if err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil {
err = errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr))
}
}
}()

if err = c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
} else if err := c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr))
}
} else if err = c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
} else if refreshSignature != "" {
if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr))
}
if err = c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
}
Expand All @@ -198,7 +196,7 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex
responder.SetExtra("refresh_token", refresh)
}

if err := storage.MaybeCommitTx(ctx, c.CoreStorage); err != nil {
if err = storage.MaybeCommitTx(ctx, c.CoreStorage); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

Expand Down
5 changes: 5 additions & 0 deletions handler/oauth2/flow_authorize_code_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,11 @@ func TestAuthorizeCodeTransactional_HandleTokenEndpointRequest(t *testing.T) {
Commit(propagatedContext).
Return(errors.New("Whoops, unable to commit transaction!")).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrServerError,
},
Expand Down
58 changes: 33 additions & 25 deletions handler/oauth2/flow_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex
}

// PopulateTokenEndpointResponse implements https://tools.ietf.org/html/rfc6749#section-6
func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error {
func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) (err error) {
if !c.CanHandleTokenEndpointRequest(requester) {
return errorsx.WithStack(fosite.ErrUnknownRequest)
}
Expand All @@ -142,27 +142,30 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
defer func() {
err = c.handleRefreshTokenEndpointStorageError(ctx, err)
}()

ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
if err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
return err
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
return err
}

if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
return err
}

storeReq := requester.Sanitize([]string{})
storeReq.SetID(ts.GetID())

if err := c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
if err = c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
return err
}

if err := c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil {
return err
}

responder.SetAccessToken(accessToken)
Expand All @@ -171,8 +174,8 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
responder.SetScopes(requester.GetGrantedScopes())
responder.SetExtra("refresh_token", refreshToken)

if err := storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, false, err)
if err = storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return err
}

return nil
Expand All @@ -188,37 +191,42 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
// attempt the valid refresh token and the access authorization
// associated with it are both revoked.
//
func (c *RefreshTokenGrantHandler) handleRefreshTokenReuse(ctx context.Context, signature string, req fosite.Requester) error {
ctx, err := storage.MaybeBeginTx(ctx, c.TokenRevocationStorage)
func (c *RefreshTokenGrantHandler) handleRefreshTokenReuse(ctx context.Context, signature string, req fosite.Requester) (err error) {
ctx, err = storage.MaybeBeginTx(ctx, c.TokenRevocationStorage)
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
defer func() {
err = c.handleRefreshTokenEndpointStorageError(ctx, err)
}()

if err := c.TokenRevocationStorage.DeleteRefreshTokenSession(ctx, signature); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
} else if err := c.TokenRevocationStorage.RevokeRefreshToken(
if err = c.TokenRevocationStorage.DeleteRefreshTokenSession(ctx, signature); err != nil {
return err
} else if err = c.TokenRevocationStorage.RevokeRefreshToken(
ctx, req.GetID(),
); err != nil && !errors.Is(err, fosite.ErrNotFound) {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
} else if err := c.TokenRevocationStorage.RevokeAccessToken(
return err
} else if err = c.TokenRevocationStorage.RevokeAccessToken(
ctx, req.GetID(),
); err != nil && !errors.Is(err, fosite.ErrNotFound) {
return c.handleRefreshTokenEndpointStorageError(ctx, true, err)
return err
}

if err := storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return c.handleRefreshTokenEndpointStorageError(ctx, false, err)
if err = storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return err
}

return nil
}

func (c *RefreshTokenGrantHandler) handleRefreshTokenEndpointStorageError(ctx context.Context, rollback bool, storageErr error) (err error) {
func (c *RefreshTokenGrantHandler) handleRefreshTokenEndpointStorageError(ctx context.Context, storageErr error) (err error) {
if storageErr == nil {
return nil
}

defer func() {
if rollback {
if rbErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rbErr != nil {
err = errorsx.WithStack(fosite.ErrServerError.WithWrap(rbErr).WithDebug(rbErr.Error()))
}
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rollBackTxnErr != nil {
err = errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr))
}
}()

Expand Down
10 changes: 10 additions & 0 deletions handler/oauth2/flow_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,11 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
Commit(propagatedContext).
Return(errors.New("Could not commit transaction!")).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrServerError,
},
Expand Down Expand Up @@ -958,6 +963,11 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
Commit(propagatedContext).
Return(fosite.ErrSerializationFailure).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrInvalidRequest,
},
Expand Down

0 comments on commit 47584f2

Please sign in to comment.