Skip to content

Commit

Permalink
feat: remove legacy lookup in users for one_time_tokens (phase II)
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed May 3, 2024
1 parent 2037c1f commit 9c90f81
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 74 deletions.
4 changes: 4 additions & 0 deletions internal/api/external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ func (ts *ExternalTestSuite) createUser(providerId string, email string, name st
ts.Require().NoError(err, "Error making new user")
ts.Require().NoError(ts.API.db.Create(u), "Error creating user")

if confirmationToken != "" {
ts.Require().NoError(models.CreateOneTimeToken(ts.API.db, u.ID, email, u.ConfirmationToken, models.ConfirmationToken), "Error creating one-time confirmation/invite token")
}

i, err := models.NewIdentity(u, "email", map[string]interface{}{
"sub": u.ID.String(),
"email": email,
Expand Down
19 changes: 1 addition & 18 deletions internal/api/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,11 @@ import (
type VerifyVariant int

const (
VerifyWithoutOTT VerifyVariant = iota
VerifyWithOTT
VerifyWithOTT VerifyVariant = iota
)

func (v VerifyVariant) String() string {
switch v {
case VerifyWithoutOTT:
return "WithoutOTT"

case VerifyWithOTT:
return "WithOTT"

Expand Down Expand Up @@ -71,7 +67,6 @@ func (ts *VerifyTestSuite) SetupTest() {

func (ts *VerifyTestSuite) VerifyWithVariants(fn func(variant VerifyVariant)) {
variants := []VerifyVariant{
VerifyWithoutOTT,
VerifyWithOTT,
}

Expand Down Expand Up @@ -146,10 +141,6 @@ func (ts *VerifyTestSuite) TestVerifyPasswordRecovery() {

recoveryToken := u.RecoveryToken

if variant == VerifyWithoutOTT {
require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID))
}

reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.RecoveryVerification, recoveryToken)
req = httptest.NewRequest(http.MethodGet, reqURL, nil)

Expand Down Expand Up @@ -249,10 +240,6 @@ func (ts *VerifyTestSuite) TestVerifySecureEmailChange() {
currentTokenHash := u.EmailChangeTokenCurrent
newTokenHash := u.EmailChangeTokenNew

if variant == VerifyWithoutOTT {
require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID))
}

u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

Expand Down Expand Up @@ -1002,10 +989,6 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() {
require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.EmailChangeTokenNew, models.EmailChangeTokenNew))
require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.PhoneChangeToken, models.PhoneChangeToken))

if variant == VerifyWithoutOTT {
require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID))
}

require.NoError(ts.T(), ts.API.db.Update(u))

var buffer bytes.Buffer
Expand Down
62 changes: 6 additions & 56 deletions internal/models/one_time_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,99 +178,57 @@ func FindOneTimeToken(tx *storage.Connection, tokenHash string, tokenTypes ...On
// FindUserByConfirmationToken finds users with the matching confirmation token.
func FindUserByConfirmationOrRecoveryToken(tx *storage.Connection, token string) (*User, error) {
ott, err := FindOneTimeToken(tx, token, ConfirmationToken, RecoveryToken)
if err != nil && !IsNotFoundError(err) {
if err != nil {
return nil, err
}

if ott == nil {
user, err := findUser(tx, "(confirmation_token = ? or recovery_token = ?) and is_sso_user = false", token, token)
if err != nil {
if IsNotFoundError(err) {
return nil, ConfirmationOrRecoveryTokenNotFoundError{}
} else {
return nil, err
}
}

return user, nil
}

return FindUserByID(tx, ott.UserID)
}

// FindUserByConfirmationToken finds users with the matching confirmation token.
func FindUserByConfirmationToken(tx *storage.Connection, token string) (*User, error) {
ott, err := FindOneTimeToken(tx, token, ConfirmationToken)
if err != nil && !IsNotFoundError(err) {
if err != nil {
return nil, err
}

if ott == nil {
user, err := findUser(tx, "confirmation_token = ? and is_sso_user = false", token)
if err != nil {
if IsNotFoundError(err) {
return nil, ConfirmationTokenNotFoundError{}
} else {
return nil, err
}
}

return user, nil
}

return FindUserByID(tx, ott.UserID)
}

// FindUserByRecoveryToken finds a user with the matching recovery token.
func FindUserByRecoveryToken(tx *storage.Connection, token string) (*User, error) {
ott, err := FindOneTimeToken(tx, token, RecoveryToken)
if err != nil && !IsNotFoundError(err) {
if err != nil {
return nil, err
}

if ott == nil {
return findUser(tx, "recovery_token = ? and is_sso_user = false", token)
}

return FindUserByID(tx, ott.UserID)
}

// FindUserByEmailChangeToken finds a user with the matching email change token.
func FindUserByEmailChangeToken(tx *storage.Connection, token string) (*User, error) {
ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent, EmailChangeTokenNew)
if err != nil && !IsNotFoundError(err) {
if err != nil {
return nil, err
}

if ott == nil {
return findUser(tx, "is_sso_user = false and (email_change_token_current = ? or email_change_token_new = ?)", token, token)
}

return FindUserByID(tx, ott.UserID)
}

// FindUserByEmailChangeCurrentAndAudience finds a user with the matching email change and audience.
func FindUserByEmailChangeCurrentAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) {
ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent)
if err != nil && !IsNotFoundError(err) {
if err != nil {
return nil, err
}

if ott == nil {
ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenCurrent)
if err != nil && !IsNotFoundError(err) {
if err != nil {
return nil, err
}
}

if ott == nil {
return findUser(
tx,
"instance_id = ? and LOWER(email) = ? and aud = ? and is_sso_user = false and (email_change_token_current = 'pkce_' || ? or email_change_token_current = ?)",
uuid.Nil, strings.ToLower(email), aud, token, token,
)
}

user, err := FindUserByID(tx, ott.UserID)
if err != nil {
return nil, err
Expand All @@ -297,14 +255,6 @@ func FindUserByEmailChangeNewAndAudience(tx *storage.Connection, email, token, a
}
}

if ott == nil {
return findUser(
tx,
"instance_id = ? and LOWER(email_change) = ? and aud = ? and is_sso_user = false and (email_change_token_new = 'pkce_' || ? or email_change_token_new = ?)",
uuid.Nil, strings.ToLower(email), aud, token, token,
)
}

user, err := FindUserByID(tx, ott.UserID)
if err != nil {
return nil, err
Expand Down

0 comments on commit 9c90f81

Please sign in to comment.