From a156a5b6d4848cd45cbd442a737cda19363e42ca Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Sun, 17 Mar 2024 11:53:44 +0800 Subject: [PATCH 01/11] feat: support invocation of http hooks --- go.mod | 5 +- go.sum | 2 + internal/api/errorcodes.go | 1 + internal/api/errors.go | 4 + internal/api/hooks.go | 191 ++++++++++++++++++++++++++-- internal/api/hooks_test.go | 144 +++++++++++++++++++++ internal/api/mfa.go | 2 +- internal/api/otp.go | 2 +- internal/api/phone.go | 25 +++- internal/api/phone_test.go | 4 +- internal/api/reauthenticate.go | 2 +- internal/api/resend.go | 4 +- internal/api/signup.go | 2 +- internal/api/token.go | 4 +- internal/api/user.go | 2 +- internal/conf/configuration.go | 13 +- internal/conf/configuration_test.go | 2 + internal/crypto/crypto.go | 28 ++++ internal/hooks/auth_hooks.go | 27 ++++ 19 files changed, 435 insertions(+), 29 deletions(-) create mode 100644 internal/api/hooks_test.go diff --git a/go.mod b/go.mod index 0baa8ffc4..9a7ef3f74 100644 --- a/go.mod +++ b/go.mod @@ -70,6 +70,7 @@ require ( github.com/fatih/structs v1.1.0 github.com/gobuffalo/pop/v6 v6.1.1 github.com/jackc/pgx/v4 v4.18.2 + github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747 github.com/xeipuuv/gojsonschema v1.2.0 @@ -146,4 +147,6 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.21 +go 1.21.0 + +toolchain go1.21.6 diff --git a/go.sum b/go.sum index 8783b5991..6b5c469ea 100644 --- a/go.sum +++ b/go.sum @@ -469,6 +469,8 @@ github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUq github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index 45dec0dd7..8af111bde 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -74,4 +74,5 @@ const ( ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit" ErrorBadCodeVerifier ErrorCode = "bad_code_verifier" ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" + ErrorHookTimeout ErrorCode = "hook_timeout" ) diff --git a/internal/api/errors.go b/internal/api/errors.go index cc6ba877b..216c9944e 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -96,6 +96,10 @@ func conflictError(fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...) } +func gatewayTimeoutError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusGatewayTimeout, errorCode, fmtString, args...) +} + // HTTPError is an error with a message and an HTTP status code. type HTTPError struct { HTTPStatus int `json:"code"` // do not rename the JSON tags! diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 5368339d8..256667db0 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -1,17 +1,36 @@ package api import ( + "bytes" "context" "encoding/json" "fmt" + "io" + "net" "net/http" + "net/http/httptrace" + "strings" + "time" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/observability" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + + "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/hooks" "github.com/supabase/auth/internal/storage" ) -func (a *API) runHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { +const ( + DefaultHTTPHookTimeout = 5 * time.Second + DefaultHTTPHookRetries = 3 + HTTPHookBackoffDuration = 2 * time.Second +) + +func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { db := a.db.WithContext(ctx) request, err := json.Marshal(input) @@ -55,12 +74,168 @@ func (a *API) runHook(ctx context.Context, tx *storage.Connection, name string, return response, nil } -// invokeHook invokes the hook code. tx can be nil, in which case a new +func readBodyWithLimit(rsp *http.Response) ([]byte, error) { + defer rsp.Body.Close() + + const limit = 20 * 1024 // 20KB + limitedReader := io.LimitedReader{R: rsp.Body, N: limit} + + body, err := io.ReadAll(&limitedReader) + if err != nil { + return nil, err + } + + if limitedReader.N <= 0 { + // Attempt to read one more byte to check if we're exactly at the limit or over + _, err := rsp.Body.Read(make([]byte, 1)) + if err == nil { + // If we could read more, then the payload was too large + return nil, fmt.Errorf("payload too large") + } + } + + return body, nil +} + +func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { + client := http.Client{ + Timeout: DefaultHTTPHookTimeout, + } + log := observability.GetLogEntry(r) + requestURL := hookConfig.URI + hookLog := log.WithFields(logrus.Fields{ + "component": "auth_hook", + "url": requestURL, + }) + + inputPayload, err := json.Marshal(input) + if err != nil { + return nil, err + } + start := time.Now() + for i := 0; i < DefaultHTTPHookRetries; i++ { + hookLog.Infof("invocation attempt: %d", i) + if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout { + return []byte{}, gatewayTimeoutError(ErrorHookTimeout, "failed to reach hook within timeout") + } + msgID := uuid.Must(uuid.NewV4()) + currentTime := time.Now() + signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) + if err != nil { + return nil, internalServerError("Failed to make request object").WithInternalError(err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("webhook-id", msgID.String()) + req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) + req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) + + watcher, req := watchForConnection(req) + rsp, err := client.Do(req) + + if err != nil { + if terr, ok := err.(net.Error); ok && terr.Timeout() { + hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) + time.Sleep(HTTPHookBackoffDuration) + continue + } else if !watcher.gotConn && i < DefaultHTTPHookRetries-1 { + hookLog.Errorf("Failed to establish a connection on attempt %d with err %s", i, err) + time.Sleep(HTTPHookBackoffDuration) + continue + } else if i == DefaultHTTPHookRetries-1 { + return nil, gatewayTimeoutError(ErrorHookTimeout, "Failed to reach hook within allotted interval") + + } else { + return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) + } + } + + switch rsp.StatusCode { + case http.StatusOK, http.StatusNoContent, http.StatusAccepted: + if rsp.Body == nil { + return nil, nil + } + body, err := readBodyWithLimit(rsp) + if err != nil { + return nil, err + } + return body, nil + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + retryAfterHeader := rsp.Header.Get("retry-after") + // Check for truthy values to allow for flexibility to swtich to time duration + if retryAfterHeader != "" { + continue + } + return []byte{}, internalServerError("Service currently unavailable") + case http.StatusBadRequest: + return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook") + case http.StatusUnauthorized: + return []byte{}, httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "Hook requires authorizaition token") + default: + return []byte{}, internalServerError("Error executing Hook") + } + } + return nil, internalServerError("error executing hook") +} + +func watchForConnection(req *http.Request) (*connectionWatcher, *http.Request) { + w := new(connectionWatcher) + t := &httptrace.ClientTrace{ + GotConn: w.GotConn, + } + + req = req.WithContext(httptrace.WithClientTrace(req.Context(), t)) + return w, req +} + +type connectionWatcher struct { + gotConn bool +} + +func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) { + c.gotConn = true +} + +func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error { + switch input.(type) { + case *hooks.CustomSMSProviderInput: + hookOutput, ok := output.(*hooks.CustomSMSProviderOutput) + if !ok { + panic("output should be *hooks.CustomSMSProviderOutput") + } + var response []byte + var err error + + if response, err = a.runHTTPHook(r, a.config.Hook.CustomSMSProvider, input, output); err != nil { + return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err) + } + if err != nil { + return err + } + + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err) + } + fmt.Printf("%v", hookOutput) + + default: + panic("unknown HTTP hook type") + } + return nil +} + +// invokePostgresHook invokes the hook code. tx can be nil, in which case a new // transaction is opened. If calling invokeHook within a transaction, always -// pass the current transaciton, as pool-exhaustion deadlocks are very easy to +// pass the current transaction, as pool-exhaustion deadlocks are very easy to // trigger. -func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, output any) error { +func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error { config := a.config + // Switch based on hook type switch input.(type) { case *hooks.MFAVerificationAttemptInput: hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput) @@ -68,7 +243,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.MFAVerificationAttemptOutput") } - if _, err := a.runHook(ctx, tx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { return internalServerError("Error invoking MFA verification hook.").WithInternalError(err) } @@ -94,7 +269,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.PasswordVerificationAttemptOutput") } - if _, err := a.runHook(ctx, tx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { return internalServerError("Error invoking password verification hook.").WithInternalError(err) } @@ -120,7 +295,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.CustomAccessTokenOutput") } - if _, err := a.runHook(ctx, tx, config.Hook.CustomAccessToken.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil { return internalServerError("Error invoking access token hook.").WithInternalError(err) } @@ -155,6 +330,6 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out return nil default: - panic("unknown hook input type") + panic("unknown Postgres hook input type") } } diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go new file mode 100644 index 000000000..38ed6d49d --- /dev/null +++ b/internal/api/hooks_test.go @@ -0,0 +1,144 @@ +package api + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" + + "gopkg.in/h2non/gock.v1" +) + +var handleApiRequest func(*http.Request) (*http.Response, error) + +type HooksTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +type MockHttpClient struct { + mock.Mock +} + +func (m *MockHttpClient) Do(req *http.Request) (*http.Response, error) { + return handleApiRequest(req) +} + +func TestHooks(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &HooksTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *HooksTestSuite) TestRunHTTPHook() { + defer gock.OffAll() + + input := hooks.CustomSMSProviderInput{ + UserID: uuid.Must(uuid.NewV4()), + Phone: "1234567890", + OTP: "123456", + } + successOutput := hooks.CustomSMSProviderOutput{Success: true} + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.CustomSMSProvider.URI = testURL + + testCases := []struct { + description string + mockResponse interface{} + status int + expectError bool + }{ + { + description: "Successful Post request with delay", + mockResponse: successOutput, + status: http.StatusOK, + expectError: false, + }, + { + description: "Too many requests without retry header should not retry", + status: http.StatusGatewayTimeout, + expectError: true, + }, + } + + for _, tc := range testCases { + ts.Run(tc.description, func() { + gock.New(ts.Config.Hook.CustomSMSProvider.URI). + Post("/"). + MatchType("json"). + Reply(tc.status). + JSON(tc.mockResponse) + + var output hooks.CustomSMSProviderOutput + req, _ := http.NewRequest("POST", ts.Config.Hook.CustomSMSProvider.URI, nil) + body, err := ts.API.runHTTPHook(req, ts.Config.Hook.CustomSMSProvider, &input, &output) + + if !tc.expectError { + require.NoError(ts.T(), err) + if body != nil { + require.NoError(ts.T(), json.Unmarshal(body, &output)) + require.True(ts.T(), output.Success) + } + } else { + require.Error(ts.T(), err) + } + require.True(ts.T(), gock.IsDone()) + }) + } +} + +func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { + defer gock.OffAll() + + input := hooks.CustomSMSProviderInput{ + UserID: uuid.Must(uuid.NewV4()), + Phone: "1234567890", + OTP: "123456", + } + successOutput := hooks.CustomSMSProviderOutput{Success: true} + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.CustomSMSProvider.URI = testURL + + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusTooManyRequests). + SetHeader("retry-after", "true") + + // Simulate an additional response for the retry attempt + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(successOutput) + + var output hooks.CustomSMSProviderOutput + + // Simulate the original HTTP request which triggered the hook + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) + + body, err := ts.API.runHTTPHook(req, ts.Config.Hook.CustomSMSProvider, &input, &output) + require.NoError(ts.T(), err) + + err = json.Unmarshal(body, &output) + require.NoError(ts.T(), err, "Unmarshal should not fail") + require.True(ts.T(), output.Success, "Expected success on retry") + + // Ensure that all expected HTTP interactions (mocks) have been called + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry") +} diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 3919cb781..b2d429dce 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -243,7 +243,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output := hooks.MFAVerificationAttemptOutput{} - err := a.invokeHook(ctx, nil, &input, &output) + err := a.invokePostgresHook(ctx, nil, &input, &output, config.Hook.MFAVerificationAttempt.URI) if err != nil { return err } diff --git a/internal/api/otp.go b/internal/api/otp.go index 99b7bae32..d0e3d6f18 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -195,7 +195,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Unable to get SMS provider").WithInternalError(err) } - mID, serr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) + mID, serr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr) } diff --git a/internal/api/phone.go b/internal/api/phone.go index f85caa6fd..4690ed5a3 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -2,6 +2,8 @@ package api import ( "bytes" + "github.com/supabase/auth/internal/hooks" + "net/http" "regexp" "strings" "text/template" @@ -40,7 +42,7 @@ func formatPhoneNumber(phone string) string { } // sendPhoneConfirmation sends an otp to the user's phone number -func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) { +func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) { config := a.config var token *string @@ -91,10 +93,23 @@ func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, p if err != nil { return "", err } - - messageID, err = smsProvider.SendMessage(phone, message, channel, otp) - if err != nil { - return messageID, err + if config.Hook.CustomSMSProvider.Enabled { + input := hooks.CustomSMSProviderInput{ + UserID: user.ID, + Phone: user.Phone.String(), + OTP: otp, + } + output := hooks.CustomSMSProviderOutput{} + err := a.invokeHTTPHook(r, &input, &output, config.Hook.CustomSMSProvider.URI) + if err != nil { + return "", err + } + } else { + + messageID, err = smsProvider.SendMessage(phone, message, channel, otp) + if err != nil { + return messageID, err + } } } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index 09810e288..d2b9bb9e6 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -72,6 +72,8 @@ func (ts *PhoneTestSuite) TestFormatPhoneNumber() { func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) cases := []struct { desc string otpType string @@ -111,7 +113,7 @@ func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { ts.Run(c.desc, func() { provider := &TestSmsProvider{} - _, err = ts.API.sendPhoneConfirmation(ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) + _, err = ts.API.sendPhoneConfirmation(req, ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) require.Equal(ts.T(), c.expected, err) u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index 84b080070..8143320aa 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -49,7 +49,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Failed to get SMS provider").WithInternalError(terr) } - mID, err := a.sendPhoneConfirmation(tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) + mID, err := a.sendPhoneConfirmation(r, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { return err } diff --git a/internal/api/resend.go b/internal/api/resend.go index fdad38c43..0ef80008d 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -134,7 +134,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - mID, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) + mID, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) if terr != nil { return terr } @@ -146,7 +146,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - mID, terr := a.sendPhoneConfirmation(tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider) + mID, terr := a.sendPhoneConfirmation(r, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider) if terr != nil { return terr } diff --git a/internal/api/signup.go b/internal/api/signup.go index 5c7e588b8..e0517f6bc 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -278,7 +278,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Unable to get SMS provider").WithInternalError(terr) } - if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { + if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr) } } diff --git a/internal/api/token.go b/internal/api/token.go index df0292711..cee79be11 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -164,7 +164,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri Valid: isValidPassword, } output := hooks.PasswordVerificationAttemptOutput{} - err := a.invokeHook(ctx, nil, &input, &output) + err := a.invokePostgresHook(ctx, nil, &input, &output, config.Hook.PasswordVerificationAttempt.URI) if err != nil { return err } @@ -339,7 +339,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u output := hooks.CustomAccessTokenOutput{} - err := a.invokeHook(ctx, tx, &input, &output) + err := a.invokePostgresHook(ctx, tx, &input, &output, config.Hook.CustomAccessToken.URI) if err != nil { return "", 0, err } diff --git a/internal/api/user.go b/internal/api/user.go index 9fe0dcef8..a9b3a2b89 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -227,7 +227,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Error finding SMS provider").WithInternalError(terr) } - if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { + if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { return internalServerError("Error sending phone change otp").WithInternalError(terr) } } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 31fb8b22d..74167dea2 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -501,8 +501,14 @@ func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { switch strings.ToLower(u.Scheme) { case "pg-functions": return validatePostgresPath(u) + case "http": + hostname := u.Hostname() + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" { + return validateHTTPHookSecrets(e.HTTPHookSecrets) + } + return fmt.Errorf("only localhost, 127.0.0.1, and ::1 are supported with http") case "https": - return validateHTTPSHookSecrets(e.HTTPHookSecrets) + return validateHTTPHookSecrets(e.HTTPHookSecrets) default: return fmt.Errorf("only postgres hooks and HTTPS functions are supported at the moment") } @@ -530,7 +536,7 @@ func isValidSecretFormat(secret string) bool { return symmetricSecretFormat.MatchString(secret) || asymmetricSecretFormat.MatchString(secret) } -func validateHTTPSHookSecrets(secrets []string) error { +func validateHTTPHookSecrets(secrets []string) error { for _, secret := range secrets { if !isValidSecretFormat(secret) { return fmt.Errorf("invalid secret format") @@ -540,9 +546,6 @@ func validateHTTPSHookSecrets(secrets []string) error { } func (e *ExtensibilityPointConfiguration) PopulateExtensibilityPoint() error { - if err := e.ValidateExtensibilityPoint(); err != nil { - return err - } u, err := url.Parse(e.URI) if err != nil { return err diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go index f881857bf..4d6eab003 100644 --- a/internal/conf/configuration_test.go +++ b/internal/conf/configuration_test.go @@ -161,8 +161,10 @@ func TestValidateExtensibilityPointURI(t *testing.T) { {desc: "Valid Postgres URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectError: false}, {desc: "Another Valid URI", uri: "pg-functions://postgres/user_management/add_user", expectError: false}, {desc: "Another Valid URI", uri: "pg-functions://postgres/MySpeCial/FUNCTION_THAT_YELLS_AT_YOU", expectError: false}, + {desc: "Valid HTTP URI", uri: "http://localhost/functions/v1/custom-sms-sender", expectError: false}, // Negative test cases + {desc: "Invalid HTTP URI", uri: "http://asdfgggg.website.co/functions/v1/custom-sms-sender", expectError: true}, {desc: "Invalid HTTPS URI (HTTP)", uri: "http://asdfgggqqwwerty.supabase.co/functions/v1/custom-sms-sender", expectError: true}, {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectError: true}, {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectError: true}, diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index e8063ab1e..590d1ba4d 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -9,6 +9,11 @@ import ( "math" "math/big" "strconv" + "strings" + "time" + + "github.com/gofrs/uuid" + standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" "github.com/pkg/errors" ) @@ -41,3 +46,26 @@ func GenerateOtp(digits int) (string, error) { func GenerateTokenHash(emailOrPhone, otp string) string { return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp))) } + +func GenerateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) { + SymmetricSignaturePrefix := "v1," + // TODO(joel): Handle asymmetric case once library has been upgraded + var signatureList []string + for _, secret := range secrets { + if strings.HasPrefix(secret, SymmetricSignaturePrefix) { + trimmedSecret := strings.TrimPrefix(secret, SymmetricSignaturePrefix) + wh, err := standardwebhooks.NewWebhook(trimmedSecret) + if err != nil { + return nil, err + } + signature, err := wh.Sign(msgID.String(), currentTime, inputPayload) + if err != nil { + return nil, err + } + signatureList = append(signatureList, signature) + } else { + return nil, errors.New("invalid signature format") + } + } + return signatureList, nil +} diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index bd3163085..39a226ce7 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -22,6 +22,10 @@ const ( HookRejection = "reject" ) +type HTTPHookInput interface { + IsHTTPHook() +} + type HookOutput interface { IsError() bool Error() string @@ -135,6 +139,17 @@ type CustomAccessTokenOutput struct { HookError AuthHookError `json:"error,omitempty"` } +type CustomSMSProviderInput struct { + UserID uuid.UUID `json:"user_id"` + Phone string `json:"phone"` + OTP string `json:"otp"` +} + +type CustomSMSProviderOutput struct { + Success bool `json:"success"` + HookError AuthHookError `json:"error,omitempty"` +} + func (mf *MFAVerificationAttemptOutput) IsError() bool { return mf.HookError.Message != "" } @@ -159,6 +174,18 @@ func (ca *CustomAccessTokenOutput) Error() string { return ca.HookError.Message } +func (cs *CustomSMSProviderOutput) IsError() bool { + return cs.HookError.Message != "" +} + +func (cs *CustomSMSProviderOutput) Error() string { + return cs.HookError.Message +} + +func (cs *CustomSMSProviderOutput) IsHTTPHook() bool { + return true +} + type AuthHookError struct { HTTPCode int `json:"http_code,omitempty"` Message string `json:"message,omitempty"` From 43df85470f7a896ff10337d07aeefda62717bf8a Mon Sep 17 00:00:00 2001 From: joel Date: Sun, 17 Mar 2024 23:30:09 +0800 Subject: [PATCH 02/11] fix: run gofmt --- internal/api/hooks.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 773041a80..256667db0 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -30,7 +30,6 @@ const ( HTTPHookBackoffDuration = 2 * time.Second ) - func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { db := a.db.WithContext(ctx) From 9b37fb9ca87e9b61436c72369914cd8ac7b21a00 Mon Sep 17 00:00:00 2001 From: joel Date: Tue, 19 Mar 2024 13:08:43 +0800 Subject: [PATCH 03/11] fix: apply suggestions --- internal/api/errors.go | 4 ---- internal/api/hooks.go | 39 ++++++--------------------------------- 2 files changed, 6 insertions(+), 37 deletions(-) diff --git a/internal/api/errors.go b/internal/api/errors.go index 216c9944e..cc6ba877b 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -96,10 +96,6 @@ func conflictError(fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...) } -func gatewayTimeoutError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusGatewayTimeout, errorCode, fmtString, args...) -} - // HTTPError is an error with a message and an HTTP status code. type HTTPError struct { HTTPStatus int `json:"code"` // do not rename the JSON tags! diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 256667db0..ce9fa7ccb 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "net/http/httptrace" "strings" "time" @@ -28,6 +27,7 @@ const ( DefaultHTTPHookTimeout = 5 * time.Second DefaultHTTPHookRetries = 3 HTTPHookBackoffDuration = 2 * time.Second + PayloadLimit = 20 * 1024 // 20KB ) func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { @@ -77,8 +77,7 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name func readBodyWithLimit(rsp *http.Response) ([]byte, error) { defer rsp.Body.Close() - const limit = 20 * 1024 // 20KB - limitedReader := io.LimitedReader{R: rsp.Body, N: limit} + limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} body, err := io.ReadAll(&limitedReader) if err != nil { @@ -116,7 +115,7 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon for i := 0; i < DefaultHTTPHookRetries; i++ { hookLog.Infof("invocation attempt: %d", i) if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout { - return []byte{}, gatewayTimeoutError(ErrorHookTimeout, "failed to reach hook within timeout") + return []byte{}, unprocessableEntityError(ErrorHookTimeout, "failed to reach hook within timeout") } msgID := uuid.Must(uuid.NewV4()) currentTime := time.Now() @@ -127,7 +126,7 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) if err != nil { - return nil, internalServerError("Failed to make request object").WithInternalError(err) + panic("Failed to make requst object") } req.Header.Set("Content-Type", "application/json") @@ -135,21 +134,14 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) - watcher, req := watchForConnection(req) rsp, err := client.Do(req) - if err != nil { - if terr, ok := err.(net.Error); ok && terr.Timeout() { + if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 { hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) time.Sleep(HTTPHookBackoffDuration) continue - } else if !watcher.gotConn && i < DefaultHTTPHookRetries-1 { - hookLog.Errorf("Failed to establish a connection on attempt %d with err %s", i, err) - time.Sleep(HTTPHookBackoffDuration) - continue } else if i == DefaultHTTPHookRetries-1 { - return nil, gatewayTimeoutError(ErrorHookTimeout, "Failed to reach hook within allotted interval") - + return nil, unprocessableEntityError(ErrorHookTimeout, "Failed to reach hook within allotted interval") } else { return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) } @@ -183,24 +175,6 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon return nil, internalServerError("error executing hook") } -func watchForConnection(req *http.Request) (*connectionWatcher, *http.Request) { - w := new(connectionWatcher) - t := &httptrace.ClientTrace{ - GotConn: w.GotConn, - } - - req = req.WithContext(httptrace.WithClientTrace(req.Context(), t)) - return w, req -} - -type connectionWatcher struct { - gotConn bool -} - -func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) { - c.gotConn = true -} - func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error { switch input.(type) { case *hooks.CustomSMSProviderInput: @@ -221,7 +195,6 @@ func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) if err := json.Unmarshal(response, hookOutput); err != nil { return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err) } - fmt.Printf("%v", hookOutput) default: panic("unknown HTTP hook type") From bcfe761074a386e8c79fa77b3377871653786259 Mon Sep 17 00:00:00 2001 From: joel Date: Tue, 19 Mar 2024 16:39:29 +0800 Subject: [PATCH 04/11] fix: remove readBodyWithLimit --- internal/api/errorcodes.go | 3 ++- internal/api/helpers_test.go | 2 +- internal/api/hooks.go | 39 ++++++++++++---------------------- internal/api/hooks_test.go | 25 +++++++++++++++------- internal/conf/configuration.go | 2 +- 5 files changed, 35 insertions(+), 36 deletions(-) diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index 8af111bde..be8eba70f 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -74,5 +74,6 @@ const ( ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit" ErrorBadCodeVerifier ErrorCode = "bad_code_verifier" ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" - ErrorHookTimeout ErrorCode = "hook_timeout" + ErrorCodeHookTimeout ErrorCode = "hook_timeout" + ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit" ) diff --git a/internal/api/helpers_test.go b/internal/api/helpers_test.go index 15f9ce4d6..84a4846e6 100644 --- a/internal/api/helpers_test.go +++ b/internal/api/helpers_test.go @@ -37,7 +37,7 @@ func TestIsValidCodeChallenge(t *testing.T) { } } -func TestIsValidPKCEParmas(t *testing.T) { +func TestIsValidPKCEParams(t *testing.T) { cases := []struct { challengeMethod string challenge string diff --git a/internal/api/hooks.go b/internal/api/hooks.go index ce9fa7ccb..af44d09ca 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -8,6 +8,7 @@ import ( "io" "net" "net/http" + "strconv" "strings" "time" @@ -74,28 +75,6 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name return response, nil } -func readBodyWithLimit(rsp *http.Response) ([]byte, error) { - defer rsp.Body.Close() - - limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} - - body, err := io.ReadAll(&limitedReader) - if err != nil { - return nil, err - } - - if limitedReader.N <= 0 { - // Attempt to read one more byte to check if we're exactly at the limit or over - _, err := rsp.Body.Read(make([]byte, 1)) - if err == nil { - // If we could read more, then the payload was too large - return nil, fmt.Errorf("payload too large") - } - } - - return body, nil -} - func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { client := http.Client{ Timeout: DefaultHTTPHookTimeout, @@ -115,7 +94,7 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon for i := 0; i < DefaultHTTPHookRetries; i++ { hookLog.Infof("invocation attempt: %d", i) if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout { - return []byte{}, unprocessableEntityError(ErrorHookTimeout, "failed to reach hook within timeout") + return []byte{}, unprocessableEntityError(ErrorCodeHookTimeout, "failed to reach hook within timeout") } msgID := uuid.Must(uuid.NewV4()) currentTime := time.Now() @@ -141,7 +120,7 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon time.Sleep(HTTPHookBackoffDuration) continue } else if i == DefaultHTTPHookRetries-1 { - return nil, unprocessableEntityError(ErrorHookTimeout, "Failed to reach hook within allotted interval") + return nil, unprocessableEntityError(ErrorCodeHookTimeout, "Failed to reach hook within allotted interval") } else { return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) } @@ -152,7 +131,17 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon if rsp.Body == nil { return nil, nil } - body, err := readBodyWithLimit(rsp) + contentLengthEntry := rsp.Header.Get("content-length") + contentLength, err := strconv.Atoi(contentLengthEntry) + if err != nil { + return nil, err + } + if contentLength >= PayloadLimit { + return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, "payload exceeded size limit") + } + defer rsp.Body.Close() + limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} + body, err := io.ReadAll(&limitedReader) if err != nil { return nil, err } diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go index 38ed6d49d..e79860b53 100644 --- a/internal/api/hooks_test.go +++ b/internal/api/hooks_test.go @@ -70,18 +70,27 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { }, { description: "Too many requests without retry header should not retry", - status: http.StatusGatewayTimeout, + status: http.StatusUnprocessableEntity, expectError: true, }, } for _, tc := range testCases { ts.Run(tc.description, func() { - gock.New(ts.Config.Hook.CustomSMSProvider.URI). - Post("/"). - MatchType("json"). - Reply(tc.status). - JSON(tc.mockResponse) + if tc.status == http.StatusOK { + gock.New(ts.Config.Hook.CustomSMSProvider.URI). + Post("/"). + MatchType("json"). + Reply(tc.status). + JSON(tc.mockResponse).SetHeader("content-length", "21") + } else { + gock.New(ts.Config.Hook.CustomSMSProvider.URI). + Post("/"). + MatchType("json"). + Reply(tc.status). + JSON(tc.mockResponse) + + } var output hooks.CustomSMSProviderOutput req, _ := http.NewRequest("POST", ts.Config.Hook.CustomSMSProvider.URI, nil) @@ -117,14 +126,14 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { Post("/"). MatchType("json"). Reply(http.StatusTooManyRequests). - SetHeader("retry-after", "true") + SetHeader("retry-after", "true").SetHeader("content-length", "21") // Simulate an additional response for the retry attempt gock.New(testURL). Post("/"). MatchType("json"). Reply(http.StatusOK). - JSON(successOutput) + JSON(successOutput).SetHeader("content-length", "21") var output hooks.CustomSMSProviderOutput diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 74167dea2..630c73798 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -503,7 +503,7 @@ func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { return validatePostgresPath(u) case "http": hostname := u.Hostname() - if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" { + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == "host.docker.internal" { return validateHTTPHookSecrets(e.HTTPHookSecrets) } return fmt.Errorf("only localhost, 127.0.0.1, and ::1 are supported with http") From 730052a97683262f68d97b5c8c8cc41da79eedf9 Mon Sep 17 00:00:00 2001 From: joel Date: Tue, 19 Mar 2024 17:36:01 +0800 Subject: [PATCH 05/11] fix: pass down context --- internal/api/hooks.go | 14 +++++++------- internal/api/hooks_test.go | 6 ++++-- internal/api/otp.go | 2 +- internal/api/phone.go | 5 +++-- internal/api/phone_test.go | 4 +++- internal/api/reauthenticate.go | 2 +- internal/api/resend.go | 4 ++-- internal/api/signup.go | 2 +- internal/api/user.go | 2 +- 9 files changed, 23 insertions(+), 18 deletions(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index af44d09ca..0ec26d414 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -75,10 +75,14 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name return response, nil } -func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { +func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { client := http.Client{ Timeout: DefaultHTTPHookTimeout, } + // TODO: Figure out what to do with ctx + _, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) + defer cancel() + log := observability.GetLogEntry(r) requestURL := hookConfig.URI hookLog := log.WithFields(logrus.Fields{ @@ -90,12 +94,8 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon if err != nil { return nil, err } - start := time.Now() for i := 0; i < DefaultHTTPHookRetries; i++ { hookLog.Infof("invocation attempt: %d", i) - if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout { - return []byte{}, unprocessableEntityError(ErrorCodeHookTimeout, "failed to reach hook within timeout") - } msgID := uuid.Must(uuid.NewV4()) currentTime := time.Now() signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) @@ -164,7 +164,7 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon return nil, internalServerError("error executing hook") } -func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error { +func (a *API) invokeHTTPHook(ctx context.Context, r *http.Request, input, output any, hookURI string) error { switch input.(type) { case *hooks.CustomSMSProviderInput: hookOutput, ok := output.(*hooks.CustomSMSProviderOutput) @@ -174,7 +174,7 @@ func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) var response []byte var err error - if response, err = a.runHTTPHook(r, a.config.Hook.CustomSMSProvider, input, output); err != nil { + if response, err = a.runHTTPHook(ctx, r, a.config.Hook.CustomSMSProvider, input, output); err != nil { return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err) } if err != nil { diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go index e79860b53..40a9da7a4 100644 --- a/internal/api/hooks_test.go +++ b/internal/api/hooks_test.go @@ -94,7 +94,8 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { var output hooks.CustomSMSProviderOutput req, _ := http.NewRequest("POST", ts.Config.Hook.CustomSMSProvider.URI, nil) - body, err := ts.API.runHTTPHook(req, ts.Config.Hook.CustomSMSProvider, &input, &output) + ctx := req.Context() + body, err := ts.API.runHTTPHook(ctx, req, ts.Config.Hook.CustomSMSProvider, &input, &output) if !tc.expectError { require.NoError(ts.T(), err) @@ -140,8 +141,9 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { // Simulate the original HTTP request which triggered the hook req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) require.NoError(ts.T(), err) + ctx := req.Context() - body, err := ts.API.runHTTPHook(req, ts.Config.Hook.CustomSMSProvider, &input, &output) + body, err := ts.API.runHTTPHook(ctx, req, ts.Config.Hook.CustomSMSProvider, &input, &output) require.NoError(ts.T(), err) err = json.Unmarshal(body, &output) diff --git a/internal/api/otp.go b/internal/api/otp.go index d0e3d6f18..a2ea95429 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -195,7 +195,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Unable to get SMS provider").WithInternalError(err) } - mID, serr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) + mID, serr := a.sendPhoneConfirmation(ctx, r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr) } diff --git a/internal/api/phone.go b/internal/api/phone.go index 4690ed5a3..c018b09a8 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "github.com/supabase/auth/internal/hooks" "net/http" "regexp" @@ -42,7 +43,7 @@ func formatPhoneNumber(phone string) string { } // sendPhoneConfirmation sends an otp to the user's phone number -func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) { +func (a *API) sendPhoneConfirmation(ctx context.Context, r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) { config := a.config var token *string @@ -100,7 +101,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use OTP: otp, } output := hooks.CustomSMSProviderOutput{} - err := a.invokeHTTPHook(r, &input, &output, config.Hook.CustomSMSProvider.URI) + err := a.invokeHTTPHook(ctx, r, &input, &output, config.Hook.CustomSMSProvider.URI) if err != nil { return "", err } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index d2b9bb9e6..06c1bb1bc 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -113,7 +113,9 @@ func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { ts.Run(c.desc, func() { provider := &TestSmsProvider{} - _, err = ts.API.sendPhoneConfirmation(req, ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) + ctx := req.Context() + + _, err = ts.API.sendPhoneConfirmation(ctx, req, ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) require.Equal(ts.T(), c.expected, err) u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index 8143320aa..1fa8de15e 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -49,7 +49,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Failed to get SMS provider").WithInternalError(terr) } - mID, err := a.sendPhoneConfirmation(r, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) + mID, err := a.sendPhoneConfirmation(ctx, r, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { return err } diff --git a/internal/api/resend.go b/internal/api/resend.go index 0ef80008d..3f48b4cd2 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -134,7 +134,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - mID, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) + mID, terr := a.sendPhoneConfirmation(ctx, r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) if terr != nil { return terr } @@ -146,7 +146,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - mID, terr := a.sendPhoneConfirmation(r, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider) + mID, terr := a.sendPhoneConfirmation(ctx, r, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider) if terr != nil { return terr } diff --git a/internal/api/signup.go b/internal/api/signup.go index e0517f6bc..c476cba05 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -278,7 +278,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Unable to get SMS provider").WithInternalError(terr) } - if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { + if _, terr := a.sendPhoneConfirmation(ctx, r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr) } } diff --git a/internal/api/user.go b/internal/api/user.go index a9b3a2b89..7be379775 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -227,7 +227,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Error finding SMS provider").WithInternalError(terr) } - if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { + if _, terr := a.sendPhoneConfirmation(ctx, r, tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { return internalServerError("Error sending phone change otp").WithInternalError(terr) } } From eba4cd3a6869bfd48edbb4416b6c272e1dba83f5 Mon Sep 17 00:00:00 2001 From: joel Date: Wed, 20 Mar 2024 10:15:40 +0800 Subject: [PATCH 06/11] fix: use content length header --- internal/api/errorcodes.go | 1 + internal/api/hooks.go | 20 +++++++++----------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index be8eba70f..30762f9f2 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -76,4 +76,5 @@ const ( ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" ErrorCodeHookTimeout ErrorCode = "hook_timeout" ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit" + ErrorCodeHookPayloadUnknown ErrorCode = "hook_payload_unknown" ) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 0ec26d414..65844efb4 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "strconv" "strings" "time" @@ -28,7 +27,7 @@ const ( DefaultHTTPHookTimeout = 5 * time.Second DefaultHTTPHookRetries = 3 HTTPHookBackoffDuration = 2 * time.Second - PayloadLimit = 20 * 1024 // 20KB + PayloadLimit = 200 * 1024 // 200KB ) func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { @@ -79,8 +78,7 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. client := http.Client{ Timeout: DefaultHTTPHookTimeout, } - // TODO: Figure out what to do with ctx - _, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) + ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) defer cancel() log := observability.GetLogEntry(r) @@ -103,9 +101,9 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. return nil, err } - req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) if err != nil { - panic("Failed to make requst object") + panic("Failed to make request object") } req.Header.Set("Content-Type", "application/json") @@ -131,10 +129,10 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. if rsp.Body == nil { return nil, nil } - contentLengthEntry := rsp.Header.Get("content-length") - contentLength, err := strconv.Atoi(contentLengthEntry) - if err != nil { - return nil, err + contentLength := rsp.ContentLength + // unknown content length is handled for by nil check on Body + if contentLength == -1 { + return nil, unprocessableEntityError(ErrorCodeHookPayloadUnknown, "payload size not known") } if contentLength >= PayloadLimit { return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, "payload exceeded size limit") @@ -156,7 +154,7 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. case http.StatusBadRequest: return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook") case http.StatusUnauthorized: - return []byte{}, httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "Hook requires authorizaition token") + return []byte{}, forbiddenError(ErrorCodeNoAuthorization, "Hook requires authorization token") default: return []byte{}, internalServerError("Error executing Hook") } From 3c3e2b0eefa29fb4fe482d2dce49de2b016c8056 Mon Sep 17 00:00:00 2001 From: joel Date: Wed, 20 Mar 2024 22:20:49 +0800 Subject: [PATCH 07/11] fix: switch to identity encoding --- internal/api/hooks.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 65844efb4..814affccd 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -110,6 +110,8 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. req.Header.Set("webhook-id", msgID.String()) req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) + // By default, Go Client sets encoding to gzip, which does not carry a content length header. + req.Header.Set("Accept-Encoding", "identity") rsp, err := client.Do(req) if err != nil { @@ -123,14 +125,13 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) } } - switch rsp.StatusCode { case http.StatusOK, http.StatusNoContent, http.StatusAccepted: if rsp.Body == nil { return nil, nil } - contentLength := rsp.ContentLength // unknown content length is handled for by nil check on Body + contentLength := rsp.ContentLength if contentLength == -1 { return nil, unprocessableEntityError(ErrorCodeHookPayloadUnknown, "payload size not known") } @@ -146,7 +147,7 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. return body, nil case http.StatusTooManyRequests, http.StatusServiceUnavailable: retryAfterHeader := rsp.Header.Get("retry-after") - // Check for truthy values to allow for flexibility to swtich to time duration + // Check for truthy values to allow for flexibility to switch to time duration if retryAfterHeader != "" { continue } From 3da7b5857a6141b13721621bf12e156fce357c8a Mon Sep 17 00:00:00 2001 From: joel Date: Thu, 21 Mar 2024 11:30:48 +0800 Subject: [PATCH 08/11] fix: move defer to right after error --- internal/api/hooks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 814affccd..7bd15b4e9 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -125,6 +125,7 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) } } + defer rsp.Body.Close() switch rsp.StatusCode { case http.StatusOK, http.StatusNoContent, http.StatusAccepted: if rsp.Body == nil { @@ -138,7 +139,6 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. if contentLength >= PayloadLimit { return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, "payload exceeded size limit") } - defer rsp.Body.Close() limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} body, err := io.ReadAll(&limitedReader) if err != nil { From 568ac5ee60b05097cb60cd7e3a94f8225fb0037d Mon Sep 17 00:00:00 2001 From: joel Date: Thu, 21 Mar 2024 14:24:00 +0800 Subject: [PATCH 09/11] fix: split error codes --- internal/api/errorcodes.go | 3 ++- internal/api/hooks.go | 13 ++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index 30762f9f2..20202320f 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -75,6 +75,7 @@ const ( ErrorBadCodeVerifier ErrorCode = "bad_code_verifier" ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" ErrorCodeHookTimeout ErrorCode = "hook_timeout" + ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry" ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit" - ErrorCodeHookPayloadUnknown ErrorCode = "hook_payload_unknown" + ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size" ) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 7bd15b4e9..d9a328c91 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -114,13 +115,16 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. req.Header.Set("Accept-Encoding", "identity") rsp, err := client.Do(req) - if err != nil { + if err != nil && errors.Is(err, context.DeadlineExceeded) { + return nil, unprocessableEntityError(ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds())) + + } else if err != nil { if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 { hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) time.Sleep(HTTPHookBackoffDuration) continue } else if i == DefaultHTTPHookRetries-1 { - return nil, unprocessableEntityError(ErrorCodeHookTimeout, "Failed to reach hook within allotted interval") + return nil, unprocessableEntityError(ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries") } else { return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) } @@ -131,13 +135,12 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. if rsp.Body == nil { return nil, nil } - // unknown content length is handled for by nil check on Body contentLength := rsp.ContentLength if contentLength == -1 { - return nil, unprocessableEntityError(ErrorCodeHookPayloadUnknown, "payload size not known") + return nil, unprocessableEntityError(ErrorCodeHookPayloadUnknownSize, "Payload size not known") } if contentLength >= PayloadLimit { - return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, "payload exceeded size limit") + return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size is: %d bytes exceeded size limit of %d bytes", contentLength, PayloadLimit)) } limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} body, err := io.ReadAll(&limitedReader) From 4a41d1ec65fe82821452e142251fc71a3fe42137 Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Wed, 27 Mar 2024 13:58:29 +0800 Subject: [PATCH 10/11] Apply suggestions from code review Co-authored-by: Stojan Dimitrovski --- internal/api/hooks.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index d9a328c91..66d6e788d 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -129,7 +129,9 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) } } + defer rsp.Body.Close() + switch rsp.StatusCode { case http.StatusOK, http.StatusNoContent, http.StatusAccepted: if rsp.Body == nil { @@ -142,7 +144,7 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. if contentLength >= PayloadLimit { return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size is: %d bytes exceeded size limit of %d bytes", contentLength, PayloadLimit)) } - limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} + limitedReader := io.LimitedReader{R: rsp.Body, N: contentLength} body, err := io.ReadAll(&limitedReader) if err != nil { return nil, err @@ -154,7 +156,7 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. if retryAfterHeader != "" { continue } - return []byte{}, internalServerError("Service currently unavailable") + return nil, internalServerError("Service currently unavailable due to hook") case http.StatusBadRequest: return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook") case http.StatusUnauthorized: From 5648f32cef46338a1cb495699309677848c3f1c6 Mon Sep 17 00:00:00 2001 From: joel Date: Wed, 27 Mar 2024 14:21:06 +0800 Subject: [PATCH 11/11] fix: update error messages --- internal/api/hooks.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 66d6e788d..655241c38 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -94,7 +94,11 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. return nil, err } for i := 0; i < DefaultHTTPHookRetries; i++ { - hookLog.Infof("invocation attempt: %d", i) + if i == 0 { + hookLog.Debugf("invocation attempt: %d", i) + } else { + hookLog.Infof("invocation attempt: %d", i) + } msgID := uuid.Must(uuid.NewV4()) currentTime := time.Now() signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) @@ -158,14 +162,14 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf. } return nil, internalServerError("Service currently unavailable due to hook") case http.StatusBadRequest: - return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook") + return nil, internalServerError("Invalid payload sent to hook") case http.StatusUnauthorized: - return []byte{}, forbiddenError(ErrorCodeNoAuthorization, "Hook requires authorization token") + return nil, internalServerError("Hook requires authorization token") default: - return []byte{}, internalServerError("Error executing Hook") + return nil, internalServerError("Error executing Hook") } } - return nil, internalServerError("error executing hook") + return nil, nil } func (a *API) invokeHTTPHook(ctx context.Context, r *http.Request, input, output any, hookURI string) error {