Skip to content

Commit

Permalink
fix: move is owned by check to load factor (#1703)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

In `loadFactor` ensure that all factors which are loaded are owned by
the user
  • Loading branch information
J0 authored Aug 3, 2024
1 parent ac14e82 commit 701a779
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 23 deletions.
6 changes: 4 additions & 2 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,26 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,
return withUser(ctx, u), nil
}

// Use only after requireAuthentication, so that there is a valid user
func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx := r.Context()
db := a.db.WithContext(ctx)
user := getUser(ctx)
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
if err != nil {
return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID")
}

observability.LogEntrySetField(r, "factor_id", factorID)

f, err := models.FindFactorByFactorID(db, factorID)
factor, err := user.FindOwnedFactorByID(db, factorID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found")
}
return nil, internalServerError("Database error loading factor").WithInternalError(err)
}
return withFactor(ctx, f), nil
return withFactor(ctx, factor), nil
}

func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) {
Expand Down
22 changes: 1 addition & 21 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ type UnenrollFactorResponse struct {
}

const (
InvalidFactorOwnerErrorMessage = "Factor does not belong to user"
QRCodeGenerationErrorMessage = "Error generating QR Code"
QRCodeGenerationErrorMessage = "Error generating QR Code"
)

func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error {
Expand Down Expand Up @@ -392,16 +391,12 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
config := a.config
factor := getFactor(ctx)
user := getUser(ctx)

switch factor.FactorType {
case models.Phone:
if !config.MFA.Phone.VerifyEnabled {
return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA verification is disabled for Phone")
}
if !factor.IsOwnedBy(user) {
return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found")
}
return a.challengePhoneFactor(w, r)

case models.TOTP:
Expand All @@ -412,9 +407,6 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
if !config.MFA.Enabled && !config.MFA.TOTP.VerifyEnabled {
return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA verification is disabled for TOTP")
}
if !factor.IsOwnedBy(user) {
return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found")
}
return a.challengeTOTPFactor(w, r)
default:
return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be TOTP or Phone")
Expand All @@ -431,11 +423,6 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V
db := a.db.WithContext(ctx)
currentIP := utilities.GetIPAddress(r)

if !factor.IsOwnedBy(user) {
// TODO: Should be changed to notFoundError. Retained as internalServerError to preserve backward compatibility.
return internalServerError(InvalidFactorOwnerErrorMessage)
}

challenge, err := factor.FindChallengeByID(db, params.ChallengeID)
if err != nil && models.IsNotFoundError(err) {
return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found")
Expand Down Expand Up @@ -572,10 +559,6 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params *
db := a.db.WithContext(ctx)
currentIP := utilities.GetIPAddress(r)

if !factor.IsOwnedBy(user) {
return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found")
}

challenge, err := factor.FindChallengeByID(db, params.ChallengeID)
if err != nil && models.IsNotFoundError(err) {
return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found")
Expand Down Expand Up @@ -732,9 +715,6 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error {
if factor.IsVerified() && !session.IsAAL2() {
return unprocessableEntityError(ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor")
}
if !factor.IsOwnedBy(user) {
return internalServerError(InvalidFactorOwnerErrorMessage)
}

err = db.Transaction(func(tx *storage.Connection) error {
var terr error
Expand Down
24 changes: 24 additions & 0 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,30 @@ func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
require.True(ts.T(), session.IsAAL2())
}

func (ts *MFATestSuite) TestChallengeFactorNotOwnedByUser() {
var buffer bytes.Buffer
email := "[email protected]"
password := "testpassword"
signUpResp := signUp(ts, email, password)

friendlyName := "testfactor"
phoneNumber := "+1234567"

otherUsersPhoneFactor := models.NewPhoneFactor(ts.TestUser, phoneNumber, friendlyName)
require.NoError(ts.T(), ts.API.db.Create(otherUsersPhoneFactor), "Error creating factor")

w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", otherUsersPhoneFactor.ID), signUpResp.Token, buffer)

expectedError := notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found")

var data HTTPError
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))

require.Equal(ts.T(), expectedError.ErrorCode, data.ErrorCode)
require.Equal(ts.T(), http.StatusNotFound, w.Code)

}

func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenResponse) {
ts.API.config.Mailer.Autoconfirm = true
var buffer bytes.Buffer
Expand Down
12 changes: 12 additions & 0 deletions internal/models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,18 @@ func (u *User) SoftDeleteUserIdentities(tx *storage.Connection) error {
return nil
}

func (u *User) FindOwnedFactorByID(tx *storage.Connection, factorID uuid.UUID) (*Factor, error) {
var factor Factor
err := tx.Where("user_id = ? AND id = ?", u.ID, factorID).First(&factor)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, &FactorNotFoundError{}
}
return nil, err
}
return &factor, nil
}

func obfuscateValue(id uuid.UUID, value string) string {
hash := sha256.Sum256([]byte(id.String() + value))
return base64.RawURLEncoding.EncodeToString(hash[:])
Expand Down

0 comments on commit 701a779

Please sign in to comment.