diff --git a/internal/api/admin.go b/internal/api/admin.go index 7df41fb65..0e5ae0cd9 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -69,9 +69,11 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, return withUser(ctx, u), nil } +// Use only after requireAuthentication, so that there is a valid user func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { ctx := r.Context() db := a.db.WithContext(ctx) + user := getUser(ctx) factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) if err != nil { return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID") @@ -79,14 +81,14 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex observability.LogEntrySetField(r, "factor_id", factorID) - f, err := models.FindFactorByFactorID(db, factorID) + factor, err := user.FindOwnedFactorByID(db, factorID) if err != nil { if models.IsNotFoundError(err) { return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") } return nil, internalServerError("Database error loading factor").WithInternalError(err) } - return withFactor(ctx, f), nil + return withFactor(ctx, factor), nil } func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { diff --git a/internal/api/mfa.go b/internal/api/mfa.go index ae4c24b74..da4eec4ee 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -65,8 +65,7 @@ type UnenrollFactorResponse struct { } const ( - InvalidFactorOwnerErrorMessage = "Factor does not belong to user" - QRCodeGenerationErrorMessage = "Error generating QR Code" + QRCodeGenerationErrorMessage = "Error generating QR Code" ) func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { @@ -392,16 +391,12 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() config := a.config factor := getFactor(ctx) - user := getUser(ctx) switch factor.FactorType { case models.Phone: if !config.MFA.Phone.VerifyEnabled { return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA verification is disabled for Phone") } - if !factor.IsOwnedBy(user) { - return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found") - } return a.challengePhoneFactor(w, r) case models.TOTP: @@ -412,9 +407,6 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { if !config.MFA.Enabled && !config.MFA.TOTP.VerifyEnabled { return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA verification is disabled for TOTP") } - if !factor.IsOwnedBy(user) { - return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found") - } return a.challengeTOTPFactor(w, r) default: return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be TOTP or Phone") @@ -431,11 +423,6 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V db := a.db.WithContext(ctx) currentIP := utilities.GetIPAddress(r) - if !factor.IsOwnedBy(user) { - // TODO: Should be changed to notFoundError. Retained as internalServerError to preserve backward compatibility. - return internalServerError(InvalidFactorOwnerErrorMessage) - } - challenge, err := factor.FindChallengeByID(db, params.ChallengeID) if err != nil && models.IsNotFoundError(err) { return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") @@ -572,10 +559,6 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * db := a.db.WithContext(ctx) currentIP := utilities.GetIPAddress(r) - if !factor.IsOwnedBy(user) { - return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found") - } - challenge, err := factor.FindChallengeByID(db, params.ChallengeID) if err != nil && models.IsNotFoundError(err) { return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") @@ -732,9 +715,6 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { if factor.IsVerified() && !session.IsAAL2() { return unprocessableEntityError(ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") } - if !factor.IsOwnedBy(user) { - return internalServerError(InvalidFactorOwnerErrorMessage) - } err = db.Transaction(func(tx *storage.Connection) error { var terr error diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index e08f20514..ed9a13e1b 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -569,6 +569,30 @@ func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() { require.True(ts.T(), session.IsAAL2()) } +func (ts *MFATestSuite) TestChallengeFactorNotOwnedByUser() { + var buffer bytes.Buffer + email := "nomfaenabled@test.com" + password := "testpassword" + signUpResp := signUp(ts, email, password) + + friendlyName := "testfactor" + phoneNumber := "+1234567" + + otherUsersPhoneFactor := models.NewPhoneFactor(ts.TestUser, phoneNumber, friendlyName) + require.NoError(ts.T(), ts.API.db.Create(otherUsersPhoneFactor), "Error creating factor") + + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", otherUsersPhoneFactor.ID), signUpResp.Token, buffer) + + expectedError := notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") + + var data HTTPError + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Equal(ts.T(), expectedError.ErrorCode, data.ErrorCode) + require.Equal(ts.T(), http.StatusNotFound, w.Code) + +} + func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenResponse) { ts.API.config.Mailer.Autoconfirm = true var buffer bytes.Buffer diff --git a/internal/models/user.go b/internal/models/user.go index 12ac52816..e50b9647d 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -906,6 +906,18 @@ func (u *User) SoftDeleteUserIdentities(tx *storage.Connection) error { return nil } +func (u *User) FindOwnedFactorByID(tx *storage.Connection, factorID uuid.UUID) (*Factor, error) { + var factor Factor + err := tx.Where("user_id = ? AND id = ?", u.ID, factorID).First(&factor) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, &FactorNotFoundError{} + } + return nil, err + } + return &factor, nil +} + func obfuscateValue(id uuid.UUID, value string) string { hash := sha256.Sum256([]byte(id.String() + value)) return base64.RawURLEncoding.EncodeToString(hash[:])