forked from supabase/auth
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: unlink identity bugs (supabase#1475)
- Loading branch information
1 parent
311cde8
commit 73e8d87
Showing
3 changed files
with
158 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,13 @@ package api | |
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/gofrs/uuid" | ||
"github.com/stretchr/testify/require" | ||
"github.com/stretchr/testify/suite" | ||
"github.com/supabase/auth/internal/api/provider" | ||
|
@@ -34,9 +37,10 @@ func (ts *IdentityTestSuite) SetupTest() { | |
models.TruncateAll(ts.API.db) | ||
|
||
// Create user | ||
u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) | ||
u, err := models.NewUser("", "one@example.com", "password", ts.Config.JWT.Aud, nil) | ||
require.NoError(ts.T(), err, "Error creating test user model") | ||
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") | ||
require.NoError(ts.T(), u.Confirm(ts.API.db)) | ||
|
||
// Create identity | ||
i, err := models.NewIdentity(u, "email", map[string]interface{}{ | ||
|
@@ -45,10 +49,31 @@ func (ts *IdentityTestSuite) SetupTest() { | |
}) | ||
require.NoError(ts.T(), err) | ||
require.NoError(ts.T(), ts.API.db.Create(i)) | ||
|
||
// Create user with 2 identities | ||
u, err = models.NewUser("123456789", "[email protected]", "password", ts.Config.JWT.Aud, nil) | ||
require.NoError(ts.T(), err, "Error creating test user model") | ||
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") | ||
require.NoError(ts.T(), u.Confirm(ts.API.db)) | ||
require.NoError(ts.T(), u.ConfirmPhone(ts.API.db)) | ||
|
||
i, err = models.NewIdentity(u, "email", map[string]interface{}{ | ||
"sub": u.ID.String(), | ||
"email": u.GetEmail(), | ||
}) | ||
require.NoError(ts.T(), err) | ||
require.NoError(ts.T(), ts.API.db.Create(i)) | ||
|
||
i2, err := models.NewIdentity(u, "phone", map[string]interface{}{ | ||
"sub": u.ID.String(), | ||
"phone": u.GetPhone(), | ||
}) | ||
require.NoError(ts.T(), err) | ||
require.NoError(ts.T(), ts.API.db.Create(i2)) | ||
} | ||
|
||
func (ts *IdentityTestSuite) TestLinkIdentityToUser() { | ||
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) | ||
u, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud) | ||
require.NoError(ts.T(), err) | ||
ctx := withTargetUser(context.Background(), u) | ||
|
||
|
@@ -79,3 +104,112 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() { | |
require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked")) | ||
require.Nil(ts.T(), u) | ||
} | ||
|
||
func (ts *IdentityTestSuite) TestUnlinkIdentityError() { | ||
ts.Config.Security.ManualLinkingEnabled = true | ||
userWithOneIdentity, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud) | ||
require.NoError(ts.T(), err) | ||
|
||
userWithTwoIdentities, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud) | ||
require.NoError(ts.T(), err) | ||
cases := []struct { | ||
desc string | ||
user *models.User | ||
identityId uuid.UUID | ||
expectedError *HTTPError | ||
}{ | ||
{ | ||
desc: "User must have at least 1 identity after unlinking", | ||
user: userWithOneIdentity, | ||
identityId: userWithOneIdentity.Identities[0].ID, | ||
expectedError: badRequestError("User must have at least 1 identity after unlinking"), | ||
}, | ||
{ | ||
desc: "Identity doesn't exist", | ||
user: userWithTwoIdentities, | ||
identityId: uuid.Must(uuid.NewV4()), | ||
expectedError: badRequestError("Identity doesn't exist"), | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
ts.Run(c.desc, func() { | ||
token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, c.user, nil, models.PasswordGrant) | ||
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", c.identityId), nil) | ||
require.NoError(ts.T(), err) | ||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) | ||
w := httptest.NewRecorder() | ||
|
||
ts.API.handler.ServeHTTP(w, req) | ||
require.Equal(ts.T(), c.expectedError.Code, w.Code) | ||
|
||
var data HTTPError | ||
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) | ||
require.Equal(ts.T(), c.expectedError.Message, data.Message) | ||
}) | ||
} | ||
} | ||
|
||
func (ts *IdentityTestSuite) TestUnlinkIdentity() { | ||
ts.Config.Security.ManualLinkingEnabled = true | ||
|
||
// we want to test 2 cases here: unlinking a phone identity and email identity from a user | ||
cases := []struct { | ||
desc string | ||
// the provider to be unlinked | ||
provider string | ||
// the remaining provider that should be linked to the user | ||
providerRemaining string | ||
}{ | ||
{ | ||
desc: "Unlink phone identity successfully", | ||
provider: "phone", | ||
providerRemaining: "email", | ||
}, | ||
{ | ||
desc: "Unlink email identity successfully", | ||
provider: "email", | ||
providerRemaining: "phone", | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
ts.Run(c.desc, func() { | ||
// teardown and reset the state of the db to prevent running into errors | ||
ts.SetupTest() | ||
u, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud) | ||
require.NoError(ts.T(), err) | ||
|
||
identity, err := models.FindIdentityByIdAndProvider(ts.API.db, u.ID.String(), c.provider) | ||
require.NoError(ts.T(), err) | ||
|
||
token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant) | ||
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", identity.ID), nil) | ||
require.NoError(ts.T(), err) | ||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) | ||
w := httptest.NewRecorder() | ||
ts.API.handler.ServeHTTP(w, req) | ||
require.Equal(ts.T(), http.StatusOK, w.Code) | ||
|
||
// sanity checks | ||
u, err = models.FindUserByID(ts.API.db, u.ID) | ||
require.NoError(ts.T(), err) | ||
require.Len(ts.T(), u.Identities, 1) | ||
require.Equal(ts.T(), u.Identities[0].Provider, c.providerRemaining) | ||
|
||
// conditional checks depending on the provider that was unlinked | ||
switch c.provider { | ||
case "phone": | ||
require.Equal(ts.T(), "", u.GetPhone()) | ||
require.Nil(ts.T(), u.PhoneConfirmedAt) | ||
case "email": | ||
require.Equal(ts.T(), "", u.GetEmail()) | ||
require.Nil(ts.T(), u.EmailConfirmedAt) | ||
} | ||
|
||
// user still has a phone / email identity linked so it should not be unconfirmed | ||
require.NotNil(ts.T(), u.ConfirmedAt) | ||
}) | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters