diff --git a/handler/oauth2/flow_refresh_test.go b/handler/oauth2/flow_refresh_test.go index 479528f2..e4b6ea34 100644 --- a/handler/oauth2/flow_refresh_test.go +++ b/handler/oauth2/flow_refresh_test.go @@ -22,10 +22,15 @@ package oauth2 import ( + "context" + "fmt" "net/url" "testing" "time" + "github.com/golang/mock/gomock" + "github.com/ory/fosite/internal" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -274,3 +279,321 @@ func TestRefreshFlow_PopulateTokenEndpointResponse(t *testing.T) { }) } } + +func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { + var mockTransactional *internal.MockTransactional + var mockRevocationStore *internal.MockTokenRevocationStorage + request := fosite.NewAccessRequest(&fosite.DefaultSession{}) + response := fosite.NewAccessResponse() + propagatedContext := context.Background() + + // some storage implementation that has support for transactions, notice the embedded type `storage.Transactional` + type transactionalStore struct { + storage.Transactional + TokenRevocationStorage + } + + for _, testCase := range []struct { + description string + setup func() + expectError error + }{ + { + description: "transaction should be committed successfully if no errors occur", + setup: func() { + request.GrantTypes = fosite.Arguments{"refresh_token"} + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(propagatedContext, nil). + Times(1) + mockRevocationStore. + EXPECT(). + GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). + Return(request, nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeAccessToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeRefreshToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + Return(nil). + Times(1) + mockTransactional. + EXPECT(). + Commit(propagatedContext). + Return(nil). + Times(1) + }, + }, + { + description: "transaction should be rolled back if call to `GetRefreshTokenSession` results in an error", + setup: func() { + request.GrantTypes = fosite.Arguments{"refresh_token"} + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(propagatedContext, nil). + Times(1) + mockRevocationStore. + EXPECT(). + GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). + Return(nil, fosite.ErrNotFound). + Times(1) + mockTransactional. + EXPECT(). + Rollback(propagatedContext). + Return(nil). + Times(1) + }, + }, + { + description: "transaction should be rolled back if call to `RevokeAccessToken` results in an error", + setup: func() { + request.GrantTypes = fosite.Arguments{"refresh_token"} + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(propagatedContext, nil). + Times(1) + mockRevocationStore. + EXPECT(). + GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). + Return(request, nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeAccessToken(propagatedContext, gomock.Any()). + Return(errors.New("Whoops, a nasty database error occurred!")). + Times(1) + mockTransactional. + EXPECT(). + Rollback(propagatedContext). + Return(nil). + Times(1) + }, + }, + { + description: "transaction should be rolled back if call to `RevokeRefreshToken` results in an error", + setup: func() { + request.GrantTypes = fosite.Arguments{"refresh_token"} + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(propagatedContext, nil). + Times(1) + mockRevocationStore. + EXPECT(). + GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). + Return(request, nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeAccessToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeRefreshToken(propagatedContext, gomock.Any()). + Return(errors.New("Whoops, a nasty database error occurred!")). + Times(1) + mockTransactional. + EXPECT(). + Rollback(propagatedContext). + Return(nil). + Times(1) + }, + }, + { + description: "transaction should be rolled back if call to `CreateAccessTokenSession` results in an error", + setup: func() { + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(propagatedContext, nil). + Times(1) + mockRevocationStore. + EXPECT(). + GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). + Return(request, nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeAccessToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeRefreshToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + Return(errors.New("Whoops, a nasty database error occurred!")). + Times(1) + mockTransactional. + EXPECT(). + Rollback(propagatedContext). + Return(nil). + Times(1) + }, + }, + { + description: "transaction should be rolled back if call to `CreateRefreshTokenSession` results in an error", + setup: func() { + request.GrantTypes = fosite.Arguments{"refresh_token"} + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(propagatedContext, nil). + Times(1) + mockRevocationStore. + EXPECT(). + GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). + Return(request, nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeAccessToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeRefreshToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + Return(errors.New("Whoops, a nasty database error occurred!")). + Times(1) + mockTransactional. + EXPECT(). + Rollback(propagatedContext). + Return(nil). + Times(1) + }, + }, + { + description: "should result in a server error if transaction cannot be created", + setup: func() { + request.GrantTypes = fosite.Arguments{"refresh_token"} + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(nil, errors.New("Could not create transaction!")). + Times(1) + }, + expectError: fosite.ErrServerError, + }, + { + description: "should result in a server error if transaction cannot be rolled back", + setup: func() { + request.GrantTypes = fosite.Arguments{"refresh_token"} + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(propagatedContext, nil). + Times(1) + mockRevocationStore. + EXPECT(). + GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). + Return(nil, fosite.ErrNotFound). + Times(1) + mockTransactional. + EXPECT(). + Rollback(propagatedContext). + Return(errors.New("Could not rollback transaction!")). + Times(1) + }, + expectError: fosite.ErrServerError, + }, + { + description: "should result in a server error if transaction cannot be committed", + setup: func() { + request.GrantTypes = fosite.Arguments{"refresh_token"} + mockTransactional. + EXPECT(). + BeginTX(propagatedContext). + Return(propagatedContext, nil). + Times(1) + mockRevocationStore. + EXPECT(). + GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). + Return(request, nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeAccessToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + RevokeRefreshToken(propagatedContext, gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + Return(nil). + Times(1) + mockRevocationStore. + EXPECT(). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + Return(nil). + Times(1) + mockTransactional. + EXPECT(). + Commit(propagatedContext). + Return(errors.New("Could not commit transaction!")). + Times(1) + }, + expectError: fosite.ErrServerError, + }, + } { + t.Run(fmt.Sprintf("scenario=%s", testCase.description), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTransactional = internal.NewMockTransactional(ctrl) + mockRevocationStore = internal.NewMockTokenRevocationStorage(ctrl) + testCase.setup() + + handler := RefreshTokenGrantHandler{ + // Notice how we are passing in a store that has support for transactions! + TokenRevocationStorage: transactionalStore{ + mockTransactional, + mockRevocationStore, + }, + AccessTokenStrategy: &hmacshaStrategy, + RefreshTokenStrategy: &hmacshaStrategy, + AccessTokenLifespan: time.Hour, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + } + + if err := handler.PopulateTokenEndpointResponse(propagatedContext, request, response); testCase.expectError != nil { + assert.EqualError(t, errors.Cause(err), testCase.expectError.Error()) + } + }) + } +}