From e4f2b59e8e1f8158b6461a384349f1a32cc1bf9a Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Fri, 12 Apr 2024 17:04:55 +0900 Subject: [PATCH] feat: refactor generate accesss token to take in request (#1531) ## What kind of change does this PR introduce? In support of the use of HTTP Hook with Custom Access Token Extension Point. We need to take in a request in order to support the Custom Access Token Hook. We use the request in the Custom access hook depends on the request to fetch the global logger. We refactor `generateAccessToken` and a wrapping method, `issueRefreshToken`, to take in a request to support this. We also add a dummy request to the tests to support this change. Supports #1528 - branched out as a separate PR so as not to bloat the main PR with peripheral changes. --- internal/api/anonymous.go | 2 +- internal/api/audit_test.go | 5 +++-- internal/api/external.go | 2 +- internal/api/identity_test.go | 3 ++- internal/api/invite_test.go | 5 +++-- internal/api/logout_test.go | 5 +++-- internal/api/mfa_test.go | 4 +++- internal/api/phone_test.go | 4 ++-- internal/api/samlacs.go | 2 +- internal/api/signup.go | 2 +- internal/api/token.go | 13 +++++++------ internal/api/token_oidc.go | 2 +- internal/api/token_refresh.go | 2 +- internal/api/user_test.go | 6 ++++-- internal/api/verify.go | 4 ++-- internal/api/verify_test.go | 3 +-- 16 files changed, 36 insertions(+), 28 deletions(-) diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go index 4024d5947..fada3cb65 100644 --- a/internal/api/anonymous.go +++ b/internal/api/anonymous.go @@ -40,7 +40,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - token, terr = a.issueRefreshToken(ctx, tx, newUser, models.Anonymous, grantParams) + token, terr = a.issueRefreshToken(r, tx, newUser, models.Anonymous, grantParams) if terr != nil { return terr } diff --git a/internal/api/audit_test.go b/internal/api/audit_test.go index 22518c704..8779ab678 100644 --- a/internal/api/audit_test.go +++ b/internal/api/audit_test.go @@ -1,7 +1,6 @@ package api import ( - "context" "encoding/json" "fmt" "net/http" @@ -54,7 +53,9 @@ func (ts *AuditTestSuite) makeSuperAdmin(email string) string { require.NoError(ts.T(), ts.API.db.Create(session)) var token string - token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.PasswordGrant) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.PasswordGrant) require.NoError(ts.T(), err, "Error generating access token") p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} diff --git a/internal/api/external.go b/internal/api/external.go index 8a00e3d25..25a41890a 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -232,7 +232,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re terr = tx.Update(flowState) } else { - token, terr = a.issueRefreshToken(ctx, tx, user, models.OAuth, grantParams) + token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams) } if terr != nil { diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go index b8f5e4305..89b92b3e0 100644 --- a/internal/api/identity_test.go +++ b/internal/api/identity_test.go @@ -219,7 +219,8 @@ func (ts *IdentityTestSuite) generateAccessTokenAndSession(ctx context.Context, require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(s)) - token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant) + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) require.NoError(ts.T(), err) return token diff --git a/internal/api/invite_test.go b/internal/api/invite_test.go index aa9002f42..1ced4caeb 100644 --- a/internal/api/invite_test.go +++ b/internal/api/invite_test.go @@ -2,7 +2,6 @@ package api import ( "bytes" - "context" "encoding/json" "fmt" "net/http" @@ -65,7 +64,9 @@ func (ts *InviteTestSuite) makeSuperAdmin(email string) string { session, err := models.NewSession(u.ID, nil) require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(session)) - token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.Invite) + + req := httptest.NewRequest(http.MethodPost, "/invite", nil) + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.Invite) require.NoError(ts.T(), err, "Error generating access token") diff --git a/internal/api/logout_test.go b/internal/api/logout_test.go index 49064cf5a..3a4094109 100644 --- a/internal/api/logout_test.go +++ b/internal/api/logout_test.go @@ -1,7 +1,6 @@ package api import ( - "context" "fmt" "net/http" "net/http/httptest" @@ -46,7 +45,9 @@ func (ts *LogoutTestSuite) SetupTest() { s, err := models.NewSession(u.ID, nil) require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(s)) - t, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + t, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) require.NoError(ts.T(), err) ts.token = t } diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 216d74043..991cc52f9 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -95,7 +95,9 @@ func (ts *MFATestSuite) SetupTest() { } func (ts *MFATestSuite) generateAAL1Token(user *models.User, sessionId *uuid.UUID) string { - token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, sessionId, models.TOTPSignIn) + // Not an actual path. Dummy request to simulate a signup request that we can use in generateAccessToken + req := httptest.NewRequest(http.MethodPost, "/factors", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, sessionId, models.TOTPSignIn) require.NoError(ts.T(), err, "Error generating access token") return token } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index b00913ca8..d451df48b 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -2,7 +2,6 @@ package api import ( "bytes" - "context" "encoding/json" "fmt" "net/http" @@ -163,7 +162,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(s)) - token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant) + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) require.NoError(ts.T(), err) cases := []struct { diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index d82117748..d50e16a29 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -283,7 +283,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { } } - token, terr = a.issueRefreshToken(ctx, tx, user, models.SSOSAML, grantParams) + token, terr = a.issueRefreshToken(r, tx, user, models.SSOSAML, grantParams) if terr != nil { return internalServerError("Unable to issue refresh token from SAML Assertion").WithInternalError(terr) diff --git a/internal/api/signup.go b/internal/api/signup.go index 07fff7667..7584719a8 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -325,7 +325,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }); terr != nil { return terr } - token, terr = a.issueRefreshToken(ctx, tx, user, models.PasswordGrant, grantParams) + token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams) if terr != nil { return terr diff --git a/internal/api/token.go b/internal/api/token.go index 02fcbd1d5..3f6a0eb1f 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -199,7 +199,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri }); terr != nil { return terr } - token, terr = a.issueRefreshToken(ctx, tx, user, models.PasswordGrant, grantParams) + token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams) if terr != nil { return terr } @@ -269,7 +269,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) }); terr != nil { return terr } - token, terr = a.issueRefreshToken(ctx, tx, user, authMethod, grantParams) + token, terr = a.issueRefreshToken(r, tx, user, authMethod, grantParams) if terr != nil { return oauthError("server_error", terr.Error()) } @@ -291,7 +291,8 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) return sendJSON(w, http.StatusOK, token) } -func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) { +func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) { + ctx := r.Context() config := a.config if sessionId == nil { return "", 0, internalServerError("Session is required to issue access token") @@ -366,7 +367,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u return signed, expiresAt, nil } -func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { +func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { config := a.config now := time.Now() @@ -389,7 +390,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u return terr } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, refreshToken.SessionId, authenticationMethod) + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, refreshToken.SessionId, authenticationMethod) if terr != nil { // Account for Hook Error httpErr, ok := terr.(*HTTPError) @@ -455,7 +456,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, return err } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &session.ID, models.TOTPSignIn) + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, &session.ID, models.TOTPSignIn) if terr != nil { httpErr, ok := terr.(*HTTPError) if ok { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 0574c3bb8..bb4370402 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -221,7 +221,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R return terr } - token, terr = a.issueRefreshToken(ctx, tx, user, models.OAuth, grantParams) + token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams) if terr != nil { return terr } diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 65bbbb031..5c9d1984d 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -210,7 +210,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h issuedToken = newToken } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, issuedToken.SessionId, models.TokenRefresh) + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, issuedToken.SessionId, models.TokenRefresh) if terr != nil { httpErr, ok := terr.(*HTTPError) if ok { diff --git a/internal/api/user_test.go b/internal/api/user_test.go index 9e3f63b12..ac97d9c24 100644 --- a/internal/api/user_test.go +++ b/internal/api/user_test.go @@ -47,7 +47,8 @@ func (ts *UserTestSuite) SetupTest() { } func (ts *UserTestSuite) generateToken(user *models.User, sessionId *uuid.UUID) string { - token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, sessionId, models.PasswordGrant) + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, sessionId, models.PasswordGrant) require.NoError(ts.T(), err, "Error generating access token") return token } @@ -57,7 +58,8 @@ func (ts *UserTestSuite) generateAccessTokenAndSession(user *models.User) string require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(session)) - token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, &session.ID, models.PasswordGrant) + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, &session.ID, models.PasswordGrant) require.NoError(ts.T(), err, "Error generating access token") return token } diff --git a/internal/api/verify.go b/internal/api/verify.go index 5d1a43aed..ef16bf178 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -179,7 +179,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa } if isImplicitFlow(flowType) { - token, terr = a.issueRefreshToken(ctx, tx, user, models.OTP, grantParams) + token, terr = a.issueRefreshToken(r, tx, user, models.OTP, grantParams) if terr != nil { return terr } @@ -278,7 +278,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP if terr := tx.Reload(user); terr != nil { return terr } - token, terr = a.issueRefreshToken(ctx, tx, user, models.OTP, grantParams) + token, terr = a.issueRefreshToken(r, tx, user, models.OTP, grantParams) if terr != nil { return terr } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 4e7cbdd0d..a2be255de 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -2,7 +2,6 @@ package api import ( "bytes" - "context" "encoding/json" "fmt" mail "github.com/supabase/auth/internal/mailer" @@ -184,7 +183,7 @@ func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(session)) - token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.MagicLink) + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) require.NoError(ts.T(), err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))