From 250d92f9a18d38089d1bf262ef9088022a446965 Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Fri, 2 Aug 2024 18:39:43 +0200 Subject: [PATCH] fix: refactor TOTP MFA into separate methods (#1698) ## What kind of change does this PR introduce? Refactors TOTP, Challenge, Enroll, and Verify into separate branches for consistency with other methods and also readability. Adds an additional check to ensure a user must own a factor in order to challenge it. --------- Co-authored-by: Kang Ming --- internal/api/mfa.go | 272 ++++++++++++++++++++++++-------------------- 1 file changed, 150 insertions(+), 122 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index affc09466..ae4c24b74 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -137,39 +137,12 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * }) } -func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { +func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { ctx := r.Context() user := getUser(ctx) - session := getSession(ctx) - config := a.config db := a.db.WithContext(ctx) - - if session == nil || user == nil { - return internalServerError("A valid session and a registered user are required to enroll a factor") - } - params := &EnrollFactorParams{} - if err := retrieveRequestParams(r, params); err != nil { - return err - } - - switch params.FactorType { - case models.Phone: - if !config.MFA.Phone.EnrollEnabled { - return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA enroll is disabled for Phone") - } - return a.enrollPhoneFactor(w, r, params) - case models.TOTP: - // Prior to the introduction of MFA.TOTP.EnrollEnabled, - // MFA.Enabled was used to configure whether TOTP was on. So - // both have to be set to false to regard the feature as - // disabled. - if !config.MFA.Enabled && !config.MFA.TOTP.EnrollEnabled { - return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA enroll is disabled for TOTP") - } - default: - return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be TOTP or Phone") - } - + config := a.config + session := getSession(ctx) issuer := "" if params.Issuer == "" { u, err := url.ParseRequestURI(config.SiteURL) @@ -263,6 +236,41 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { }) } +func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + session := getSession(ctx) + config := a.config + + if session == nil || user == nil { + return internalServerError("A valid session and a registered user are required to enroll a factor") + } + params := &EnrollFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + switch params.FactorType { + case models.Phone: + if !config.MFA.Phone.EnrollEnabled { + return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA enroll is disabled for Phone") + } + return a.enrollPhoneFactor(w, r, params) + case models.TOTP: + // Prior to the introduction of MFA.TOTP.EnrollEnabled, + // MFA.Enabled was used to configure whether TOTP was on. So + // both have to be set to false to regard the feature as + // disabled. + if !config.MFA.Enabled && !config.MFA.TOTP.EnrollEnabled { + return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA enroll is disabled for TOTP") + } + return a.enrollTOTPFactor(w, r, params) + default: + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be totp or phone") + } + +} + func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() config := a.config @@ -349,33 +357,14 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error }) } -func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { +func (a *API) challengeTOTPFactor(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() config := a.config db := a.db.WithContext(ctx) user := getUser(ctx) factor := getFactor(ctx) - ipAddress := utilities.GetIPAddress(r) - switch factor.FactorType { - case models.Phone: - if !config.MFA.Phone.VerifyEnabled { - return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA verification is disabled for Phone") - } - return a.challengePhoneFactor(w, r) - - case models.TOTP: - // Prior to the introduction of MFA.TOTP.VerifyEnabled, - // MFA.Enabled was used to configure whether TOTP was on. So - // both have to be set to false to regard the feature as - // disabled. - if !config.MFA.Enabled && !config.MFA.TOTP.VerifyEnabled { - return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA verification is disabled for TOTP") - } - default: - return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be TOTP or Phone") - } challenge := factor.CreateChallenge(ipAddress) if err := db.Transaction(func(tx *storage.Connection) error { @@ -399,17 +388,52 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { }) } -func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { +func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() config := a.config + factor := getFactor(ctx) + user := getUser(ctx) + + switch factor.FactorType { + case models.Phone: + if !config.MFA.Phone.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA verification is disabled for Phone") + } + if !factor.IsOwnedBy(user) { + return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found") + } + return a.challengePhoneFactor(w, r) + + case models.TOTP: + // Prior to the introduction of MFA.TOTP.VerifyEnabled, + // MFA.Enabled was used to configure whether TOTP was on. So + // both have to be set to false to regard the feature as + // disabled. + if !config.MFA.Enabled && !config.MFA.TOTP.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA verification is disabled for TOTP") + } + if !factor.IsOwnedBy(user) { + return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found") + } + return a.challengeTOTPFactor(w, r) + default: + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be TOTP or Phone") + } + +} + +func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { + var err error + ctx := r.Context() user := getUser(ctx) factor := getFactor(ctx) + config := a.config db := a.db.WithContext(ctx) currentIP := utilities.GetIPAddress(r) if !factor.IsOwnedBy(user) { - return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found") - + // TODO: Should be changed to notFoundError. Retained as internalServerError to preserve backward compatibility. + return internalServerError(InvalidFactorOwnerErrorMessage) } challenge, err := factor.FindChallengeByID(db, params.ChallengeID) @@ -429,17 +453,24 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * } return unprocessableEntityError(ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } - otpCode, shouldReEncrypt, err := challenge.GetOtpCode(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + + secret, shouldReEncrypt, err := factor.GetSecret(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) if err != nil { return internalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) } - valid := subtle.ConstantTimeCompare([]byte(otpCode), []byte(params.Code)) == 1 + + valid, verr := totp.ValidateCustom(params.Code, secret, time.Now().UTC(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + if config.Hook.MFAVerificationAttempt.Enabled { input := hooks.MFAVerificationAttemptInput{ - UserID: user.ID, - FactorID: factor.ID, - FactorType: factor.FactorType, - Valid: valid, + UserID: user.ID, + FactorID: factor.ID, + Valid: valid, } output := hooks.MFAVerificationAttemptOutput{} @@ -462,15 +493,15 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * } if !valid { if shouldReEncrypt && config.Security.DBEncryption.Encrypt { - if err := challenge.SetOtpCode(otpCode, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + if err := factor.SetSecret(secret, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { return err } - if err := db.UpdateOnly(challenge, "otp_code"); err != nil { + if err := db.UpdateOnly(factor, "secret"); err != nil { return err } } - return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid MFA Phone code entered") + return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered").WithInternalError(verr) } var token *AccessTokenResponse @@ -480,7 +511,6 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ "factor_id": factor.ID, "challenge_id": challenge.ID, - "factor_type": factor.FactorType, }); terr != nil { return terr } @@ -492,12 +522,23 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * return terr } } + if shouldReEncrypt && config.Security.DBEncryption.Encrypt { + es, terr := crypto.NewEncryptedString(factor.ID.String(), []byte(secret), config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey) + if terr != nil { + return terr + } + + factor.Secret = es.String() + if terr := tx.UpdateOnly(factor, "secret"); terr != nil { + return terr + } + } user, terr = models.FindUserByID(tx, user.ID) if terr != nil { return terr } - token, terr = a.updateMFASessionAndClaims(r, tx, user, models.MFAPhone, models.GrantParams{ + token, terr = a.updateMFASessionAndClaims(r, tx, user, models.TOTPSignIn, models.GrantParams{ FactorID: &factor.ID, }) if terr != nil { @@ -520,45 +561,19 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * metering.RecordLogin(string(models.MFACodeLoginAction), user.ID) return sendJSON(w, http.StatusOK, token) + } -func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { - var err error +func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { ctx := r.Context() + config := a.config user := getUser(ctx) factor := getFactor(ctx) - config := a.config db := a.db.WithContext(ctx) - - params := &VerifyFactorParams{} - if err := retrieveRequestParams(r, params); err != nil { - return err - } - - switch factor.FactorType { - case models.Phone: - if !config.MFA.Phone.VerifyEnabled { - return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA verification is disabled for Phone") - } - if params.Code == "" { - return badRequestError(ErrorCodeValidationFailed, "Code needs to be non-empty") - } - return a.verifyPhoneFactor(w, r, params) - case models.TOTP: - if !config.MFA.TOTP.VerifyEnabled { - return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA verification is disabled for TOTP") - } - if params.Code == "" { - return badRequestError(ErrorCodeValidationFailed, "Code needs to be non-empty") - } - default: - return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be TOTP or Phone") - } - currentIP := utilities.GetIPAddress(r) if !factor.IsOwnedBy(user) { - return internalServerError(InvalidFactorOwnerErrorMessage) + return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor not found") } challenge, err := factor.FindChallengeByID(db, params.ChallengeID) @@ -578,24 +593,17 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } return unprocessableEntityError(ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } - - secret, shouldReEncrypt, err := factor.GetSecret(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + otpCode, shouldReEncrypt, err := challenge.GetOtpCode(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) if err != nil { return internalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) } - - valid, verr := totp.ValidateCustom(params.Code, secret, time.Now().UTC(), totp.ValidateOpts{ - Period: 30, - Skew: 1, - Digits: otp.DigitsSix, - Algorithm: otp.AlgorithmSHA1, - }) - + valid := subtle.ConstantTimeCompare([]byte(otpCode), []byte(params.Code)) == 1 if config.Hook.MFAVerificationAttempt.Enabled { input := hooks.MFAVerificationAttemptInput{ - UserID: user.ID, - FactorID: factor.ID, - Valid: valid, + UserID: user.ID, + FactorID: factor.ID, + FactorType: factor.FactorType, + Valid: valid, } output := hooks.MFAVerificationAttemptOutput{} @@ -618,15 +626,15 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } if !valid { if shouldReEncrypt && config.Security.DBEncryption.Encrypt { - if err := factor.SetSecret(secret, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + if err := challenge.SetOtpCode(otpCode, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { return err } - if err := db.UpdateOnly(factor, "secret"); err != nil { + if err := db.UpdateOnly(challenge, "otp_code"); err != nil { return err } } - return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered").WithInternalError(verr) + return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid MFA Phone code entered") } var token *AccessTokenResponse @@ -636,6 +644,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ "factor_id": factor.ID, "challenge_id": challenge.ID, + "factor_type": factor.FactorType, }); terr != nil { return terr } @@ -647,23 +656,12 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { return terr } } - if shouldReEncrypt && config.Security.DBEncryption.Encrypt { - es, terr := crypto.NewEncryptedString(factor.ID.String(), []byte(secret), config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey) - if terr != nil { - return terr - } - - factor.Secret = es.String() - if terr := tx.UpdateOnly(factor, "secret"); terr != nil { - return terr - } - } user, terr = models.FindUserByID(tx, user.ID) if terr != nil { return terr } - token, terr = a.updateMFASessionAndClaims(r, tx, user, models.TOTPSignIn, models.GrantParams{ + token, terr = a.updateMFASessionAndClaims(r, tx, user, models.MFAPhone, models.GrantParams{ FactorID: &factor.ID, }) if terr != nil { @@ -686,6 +684,36 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { metering.RecordLogin(string(models.MFACodeLoginAction), user.ID) return sendJSON(w, http.StatusOK, token) +} + +func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + factor := getFactor(ctx) + config := a.config + + params := &VerifyFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if params.Code == "" { + return badRequestError(ErrorCodeValidationFailed, "Code needs to be non-empty") + } + + switch factor.FactorType { + case models.Phone: + if !config.MFA.Phone.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA verification is disabled for Phone") + } + + return a.verifyPhoneFactor(w, r, params) + case models.TOTP: + if !config.MFA.TOTP.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA verification is disabled for TOTP") + } + return a.verifyTOTPFactor(w, r, params) + default: + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be TOTP or Phone") + } }