Skip to content

Commit

Permalink
feat: refactor generate accesss token to take in request (#1531)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

In support of the use of HTTP Hook with Custom Access Token Extension
Point.

We need to take in a request in order to support the Custom Access Token
Hook. We use the request in the Custom access hook depends on the
request to fetch the global logger. We refactor `generateAccessToken`
and a wrapping method, `issueRefreshToken`, to take in a request to
support this.

We also add a dummy request to the tests to support this change.
Supports #1528 - branched out as a separate PR so as not to bloat the
main PR with peripheral changes.
  • Loading branch information
J0 authored Apr 12, 2024
1 parent 56b1c45 commit e4f2b59
Show file tree
Hide file tree
Showing 16 changed files with 36 additions and 28 deletions.
2 changes: 1 addition & 1 deletion internal/api/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return terr
}
token, terr = a.issueRefreshToken(ctx, tx, newUser, models.Anonymous, grantParams)
token, terr = a.issueRefreshToken(r, tx, newUser, models.Anonymous, grantParams)
if terr != nil {
return terr
}
Expand Down
5 changes: 3 additions & 2 deletions internal/api/audit_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -54,7 +53,9 @@ func (ts *AuditTestSuite) makeSuperAdmin(email string) string {
require.NoError(ts.T(), ts.API.db.Create(session))

var token string
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.PasswordGrant)

req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil)
token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.PasswordGrant)
require.NoError(ts.T(), err, "Error generating access token")

p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re

terr = tx.Update(flowState)
} else {
token, terr = a.issueRefreshToken(ctx, tx, user, models.OAuth, grantParams)
token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams)
}

if terr != nil {
Expand Down
3 changes: 2 additions & 1 deletion internal/api/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ func (ts *IdentityTestSuite) generateAccessTokenAndSession(ctx context.Context,
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(s))

token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant)
req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil)
token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant)
require.NoError(ts.T(), err)
return token

Expand Down
5 changes: 3 additions & 2 deletions internal/api/invite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -65,7 +64,9 @@ func (ts *InviteTestSuite) makeSuperAdmin(email string) string {
session, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(session))
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.Invite)

req := httptest.NewRequest(http.MethodPost, "/invite", nil)
token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.Invite)

require.NoError(ts.T(), err, "Error generating access token")

Expand Down
5 changes: 3 additions & 2 deletions internal/api/logout_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -46,7 +45,9 @@ func (ts *LogoutTestSuite) SetupTest() {
s, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(s))
t, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant)

req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil)
t, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant)
require.NoError(ts.T(), err)
ts.token = t
}
Expand Down
4 changes: 3 additions & 1 deletion internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ func (ts *MFATestSuite) SetupTest() {
}

func (ts *MFATestSuite) generateAAL1Token(user *models.User, sessionId *uuid.UUID) string {
token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, sessionId, models.TOTPSignIn)
// Not an actual path. Dummy request to simulate a signup request that we can use in generateAccessToken
req := httptest.NewRequest(http.MethodPost, "/factors", nil)
token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, sessionId, models.TOTPSignIn)
require.NoError(ts.T(), err, "Error generating access token")
return token
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/phone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -163,7 +162,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() {
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(s))

token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant)
req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil)
token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant)
require.NoError(ts.T(), err)

cases := []struct {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/samlacs.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {
}
}

token, terr = a.issueRefreshToken(ctx, tx, user, models.SSOSAML, grantParams)
token, terr = a.issueRefreshToken(r, tx, user, models.SSOSAML, grantParams)

if terr != nil {
return internalServerError("Unable to issue refresh token from SAML Assertion").WithInternalError(terr)
Expand Down
2 changes: 1 addition & 1 deletion internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}); terr != nil {
return terr
}
token, terr = a.issueRefreshToken(ctx, tx, user, models.PasswordGrant, grantParams)
token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams)

if terr != nil {
return terr
Expand Down
13 changes: 7 additions & 6 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
}); terr != nil {
return terr
}
token, terr = a.issueRefreshToken(ctx, tx, user, models.PasswordGrant, grantParams)
token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -269,7 +269,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
}); terr != nil {
return terr
}
token, terr = a.issueRefreshToken(ctx, tx, user, authMethod, grantParams)
token, terr = a.issueRefreshToken(r, tx, user, authMethod, grantParams)
if terr != nil {
return oauthError("server_error", terr.Error())
}
Expand All @@ -291,7 +291,8 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
return sendJSON(w, http.StatusOK, token)
}

func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
ctx := r.Context()
config := a.config
if sessionId == nil {
return "", 0, internalServerError("Session is required to issue access token")
Expand Down Expand Up @@ -366,7 +367,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u
return signed, expiresAt, nil
}

func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) {
func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) {
config := a.config

now := time.Now()
Expand All @@ -389,7 +390,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u
return terr
}

tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, refreshToken.SessionId, authenticationMethod)
tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, refreshToken.SessionId, authenticationMethod)
if terr != nil {
// Account for Hook Error
httpErr, ok := terr.(*HTTPError)
Expand Down Expand Up @@ -455,7 +456,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
return err
}

tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &session.ID, models.TOTPSignIn)
tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, &session.ID, models.TOTPSignIn)
if terr != nil {
httpErr, ok := terr.(*HTTPError)
if ok {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/token_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user, models.OAuth, grantParams)
token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams)
if terr != nil {
return terr
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
issuedToken = newToken
}

tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, issuedToken.SessionId, models.TokenRefresh)
tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, issuedToken.SessionId, models.TokenRefresh)
if terr != nil {
httpErr, ok := terr.(*HTTPError)
if ok {
Expand Down
6 changes: 4 additions & 2 deletions internal/api/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ func (ts *UserTestSuite) SetupTest() {
}

func (ts *UserTestSuite) generateToken(user *models.User, sessionId *uuid.UUID) string {
token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, sessionId, models.PasswordGrant)
req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil)
token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, sessionId, models.PasswordGrant)
require.NoError(ts.T(), err, "Error generating access token")
return token
}
Expand All @@ -57,7 +58,8 @@ func (ts *UserTestSuite) generateAccessTokenAndSession(user *models.User) string
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(session))

token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, &session.ID, models.PasswordGrant)
req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil)
token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, &session.ID, models.PasswordGrant)
require.NoError(ts.T(), err, "Error generating access token")
return token
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
}

if isImplicitFlow(flowType) {
token, terr = a.issueRefreshToken(ctx, tx, user, models.OTP, grantParams)
token, terr = a.issueRefreshToken(r, tx, user, models.OTP, grantParams)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -278,7 +278,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP
if terr := tx.Reload(user); terr != nil {
return terr
}
token, terr = a.issueRefreshToken(ctx, tx, user, models.OTP, grantParams)
token, terr = a.issueRefreshToken(r, tx, user, models.OTP, grantParams)
if terr != nil {
return terr
}
Expand Down
3 changes: 1 addition & 2 deletions internal/api/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
mail "github.com/supabase/auth/internal/mailer"
Expand Down Expand Up @@ -184,7 +183,7 @@ func (ts *VerifyTestSuite) TestVerifySecureEmailChange() {
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(session))

token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.MagicLink)
token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink)
require.NoError(ts.T(), err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

Expand Down

0 comments on commit e4f2b59

Please sign in to comment.