diff --git a/go.mod b/go.mod index 0baa8ffc4..9a7ef3f74 100644 --- a/go.mod +++ b/go.mod @@ -70,6 +70,7 @@ require ( github.com/fatih/structs v1.1.0 github.com/gobuffalo/pop/v6 v6.1.1 github.com/jackc/pgx/v4 v4.18.2 + github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747 github.com/xeipuuv/gojsonschema v1.2.0 @@ -146,4 +147,6 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.21 +go 1.21.0 + +toolchain go1.21.6 diff --git a/go.sum b/go.sum index 8783b5991..6b5c469ea 100644 --- a/go.sum +++ b/go.sum @@ -469,6 +469,8 @@ github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUq github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index 45dec0dd7..20202320f 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -74,4 +74,8 @@ const ( ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit" ErrorBadCodeVerifier ErrorCode = "bad_code_verifier" ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" + ErrorCodeHookTimeout ErrorCode = "hook_timeout" + ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry" + ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit" + ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size" ) diff --git a/internal/api/helpers_test.go b/internal/api/helpers_test.go index 15f9ce4d6..84a4846e6 100644 --- a/internal/api/helpers_test.go +++ b/internal/api/helpers_test.go @@ -37,7 +37,7 @@ func TestIsValidCodeChallenge(t *testing.T) { } } -func TestIsValidPKCEParmas(t *testing.T) { +func TestIsValidPKCEParams(t *testing.T) { cases := []struct { challengeMethod string challenge string diff --git a/internal/api/hooks.go b/internal/api/hooks.go index ea8b337df..655241c38 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -1,16 +1,36 @@ package api import ( + "bytes" "context" "encoding/json" + "errors" "fmt" + "io" + "net" "net/http" + "strings" + "time" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/observability" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + + "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/hooks" "github.com/supabase/auth/internal/storage" ) +const ( + DefaultHTTPHookTimeout = 5 * time.Second + DefaultHTTPHookRetries = 3 + HTTPHookBackoffDuration = 2 * time.Second + PayloadLimit = 200 * 1024 // 200KB +) + func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { db := a.db.WithContext(ctx) @@ -55,12 +75,137 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name return response, nil } -// invokeHook invokes the hook code. tx can be nil, in which case a new +func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { + client := http.Client{ + Timeout: DefaultHTTPHookTimeout, + } + ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) + defer cancel() + + log := observability.GetLogEntry(r) + requestURL := hookConfig.URI + hookLog := log.WithFields(logrus.Fields{ + "component": "auth_hook", + "url": requestURL, + }) + + inputPayload, err := json.Marshal(input) + if err != nil { + return nil, err + } + for i := 0; i < DefaultHTTPHookRetries; i++ { + if i == 0 { + hookLog.Debugf("invocation attempt: %d", i) + } else { + hookLog.Infof("invocation attempt: %d", i) + } + msgID := uuid.Must(uuid.NewV4()) + currentTime := time.Now() + signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) + if err != nil { + panic("Failed to make request object") + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("webhook-id", msgID.String()) + req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) + req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) + // By default, Go Client sets encoding to gzip, which does not carry a content length header. + req.Header.Set("Accept-Encoding", "identity") + + rsp, err := client.Do(req) + if err != nil && errors.Is(err, context.DeadlineExceeded) { + return nil, unprocessableEntityError(ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds())) + + } else if err != nil { + if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 { + hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) + time.Sleep(HTTPHookBackoffDuration) + continue + } else if i == DefaultHTTPHookRetries-1 { + return nil, unprocessableEntityError(ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries") + } else { + return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) + } + } + + defer rsp.Body.Close() + + switch rsp.StatusCode { + case http.StatusOK, http.StatusNoContent, http.StatusAccepted: + if rsp.Body == nil { + return nil, nil + } + contentLength := rsp.ContentLength + if contentLength == -1 { + return nil, unprocessableEntityError(ErrorCodeHookPayloadUnknownSize, "Payload size not known") + } + if contentLength >= PayloadLimit { + return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size is: %d bytes exceeded size limit of %d bytes", contentLength, PayloadLimit)) + } + limitedReader := io.LimitedReader{R: rsp.Body, N: contentLength} + body, err := io.ReadAll(&limitedReader) + if err != nil { + return nil, err + } + return body, nil + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + retryAfterHeader := rsp.Header.Get("retry-after") + // Check for truthy values to allow for flexibility to switch to time duration + if retryAfterHeader != "" { + continue + } + return nil, internalServerError("Service currently unavailable due to hook") + case http.StatusBadRequest: + return nil, internalServerError("Invalid payload sent to hook") + case http.StatusUnauthorized: + return nil, internalServerError("Hook requires authorization token") + default: + return nil, internalServerError("Error executing Hook") + } + } + return nil, nil +} + +func (a *API) invokeHTTPHook(ctx context.Context, r *http.Request, input, output any, hookURI string) error { + switch input.(type) { + case *hooks.CustomSMSProviderInput: + hookOutput, ok := output.(*hooks.CustomSMSProviderOutput) + if !ok { + panic("output should be *hooks.CustomSMSProviderOutput") + } + var response []byte + var err error + + if response, err = a.runHTTPHook(ctx, r, a.config.Hook.CustomSMSProvider, input, output); err != nil { + return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err) + } + if err != nil { + return err + } + + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err) + } + + default: + panic("unknown HTTP hook type") + } + return nil +} + +// invokePostgresHook invokes the hook code. tx can be nil, in which case a new // transaction is opened. If calling invokeHook within a transaction, always -// pass the current transaciton, as pool-exhaustion deadlocks are very easy to +// pass the current transaction, as pool-exhaustion deadlocks are very easy to // trigger. -func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, output any) error { +func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error { config := a.config + // Switch based on hook type switch input.(type) { case *hooks.MFAVerificationAttemptInput: hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput) @@ -68,7 +213,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.MFAVerificationAttemptOutput") } - if _, err := a.runPostgresHook(ctx, tx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { return internalServerError("Error invoking MFA verification hook.").WithInternalError(err) } @@ -94,7 +239,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.PasswordVerificationAttemptOutput") } - if _, err := a.runPostgresHook(ctx, tx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { return internalServerError("Error invoking password verification hook.").WithInternalError(err) } @@ -120,7 +265,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.CustomAccessTokenOutput") } - if _, err := a.runPostgresHook(ctx, tx, config.Hook.CustomAccessToken.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil { return internalServerError("Error invoking access token hook.").WithInternalError(err) } @@ -155,6 +300,6 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out return nil default: - panic("unknown hook input type") + panic("unknown Postgres hook input type") } } diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go new file mode 100644 index 000000000..40a9da7a4 --- /dev/null +++ b/internal/api/hooks_test.go @@ -0,0 +1,155 @@ +package api + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" + + "gopkg.in/h2non/gock.v1" +) + +var handleApiRequest func(*http.Request) (*http.Response, error) + +type HooksTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +type MockHttpClient struct { + mock.Mock +} + +func (m *MockHttpClient) Do(req *http.Request) (*http.Response, error) { + return handleApiRequest(req) +} + +func TestHooks(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &HooksTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *HooksTestSuite) TestRunHTTPHook() { + defer gock.OffAll() + + input := hooks.CustomSMSProviderInput{ + UserID: uuid.Must(uuid.NewV4()), + Phone: "1234567890", + OTP: "123456", + } + successOutput := hooks.CustomSMSProviderOutput{Success: true} + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.CustomSMSProvider.URI = testURL + + testCases := []struct { + description string + mockResponse interface{} + status int + expectError bool + }{ + { + description: "Successful Post request with delay", + mockResponse: successOutput, + status: http.StatusOK, + expectError: false, + }, + { + description: "Too many requests without retry header should not retry", + status: http.StatusUnprocessableEntity, + expectError: true, + }, + } + + for _, tc := range testCases { + ts.Run(tc.description, func() { + if tc.status == http.StatusOK { + gock.New(ts.Config.Hook.CustomSMSProvider.URI). + Post("/"). + MatchType("json"). + Reply(tc.status). + JSON(tc.mockResponse).SetHeader("content-length", "21") + } else { + gock.New(ts.Config.Hook.CustomSMSProvider.URI). + Post("/"). + MatchType("json"). + Reply(tc.status). + JSON(tc.mockResponse) + + } + + var output hooks.CustomSMSProviderOutput + req, _ := http.NewRequest("POST", ts.Config.Hook.CustomSMSProvider.URI, nil) + ctx := req.Context() + body, err := ts.API.runHTTPHook(ctx, req, ts.Config.Hook.CustomSMSProvider, &input, &output) + + if !tc.expectError { + require.NoError(ts.T(), err) + if body != nil { + require.NoError(ts.T(), json.Unmarshal(body, &output)) + require.True(ts.T(), output.Success) + } + } else { + require.Error(ts.T(), err) + } + require.True(ts.T(), gock.IsDone()) + }) + } +} + +func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { + defer gock.OffAll() + + input := hooks.CustomSMSProviderInput{ + UserID: uuid.Must(uuid.NewV4()), + Phone: "1234567890", + OTP: "123456", + } + successOutput := hooks.CustomSMSProviderOutput{Success: true} + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.CustomSMSProvider.URI = testURL + + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusTooManyRequests). + SetHeader("retry-after", "true").SetHeader("content-length", "21") + + // Simulate an additional response for the retry attempt + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(successOutput).SetHeader("content-length", "21") + + var output hooks.CustomSMSProviderOutput + + // Simulate the original HTTP request which triggered the hook + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) + ctx := req.Context() + + body, err := ts.API.runHTTPHook(ctx, req, ts.Config.Hook.CustomSMSProvider, &input, &output) + require.NoError(ts.T(), err) + + err = json.Unmarshal(body, &output) + require.NoError(ts.T(), err, "Unmarshal should not fail") + require.True(ts.T(), output.Success, "Expected success on retry") + + // Ensure that all expected HTTP interactions (mocks) have been called + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry") +} diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 3919cb781..b2d429dce 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -243,7 +243,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output := hooks.MFAVerificationAttemptOutput{} - err := a.invokeHook(ctx, nil, &input, &output) + err := a.invokePostgresHook(ctx, nil, &input, &output, config.Hook.MFAVerificationAttempt.URI) if err != nil { return err } diff --git a/internal/api/otp.go b/internal/api/otp.go index 99b7bae32..a2ea95429 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -195,7 +195,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Unable to get SMS provider").WithInternalError(err) } - mID, serr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) + mID, serr := a.sendPhoneConfirmation(ctx, r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr) } diff --git a/internal/api/phone.go b/internal/api/phone.go index f85caa6fd..c018b09a8 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -2,6 +2,9 @@ package api import ( "bytes" + "context" + "github.com/supabase/auth/internal/hooks" + "net/http" "regexp" "strings" "text/template" @@ -40,7 +43,7 @@ func formatPhoneNumber(phone string) string { } // sendPhoneConfirmation sends an otp to the user's phone number -func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) { +func (a *API) sendPhoneConfirmation(ctx context.Context, r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) { config := a.config var token *string @@ -91,10 +94,23 @@ func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, p if err != nil { return "", err } - - messageID, err = smsProvider.SendMessage(phone, message, channel, otp) - if err != nil { - return messageID, err + if config.Hook.CustomSMSProvider.Enabled { + input := hooks.CustomSMSProviderInput{ + UserID: user.ID, + Phone: user.Phone.String(), + OTP: otp, + } + output := hooks.CustomSMSProviderOutput{} + err := a.invokeHTTPHook(ctx, r, &input, &output, config.Hook.CustomSMSProvider.URI) + if err != nil { + return "", err + } + } else { + + messageID, err = smsProvider.SendMessage(phone, message, channel, otp) + if err != nil { + return messageID, err + } } } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index 09810e288..06c1bb1bc 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -72,6 +72,8 @@ func (ts *PhoneTestSuite) TestFormatPhoneNumber() { func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) cases := []struct { desc string otpType string @@ -111,7 +113,9 @@ func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { ts.Run(c.desc, func() { provider := &TestSmsProvider{} - _, err = ts.API.sendPhoneConfirmation(ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) + ctx := req.Context() + + _, err = ts.API.sendPhoneConfirmation(ctx, req, ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) require.Equal(ts.T(), c.expected, err) u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index 84b080070..1fa8de15e 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -49,7 +49,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Failed to get SMS provider").WithInternalError(terr) } - mID, err := a.sendPhoneConfirmation(tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) + mID, err := a.sendPhoneConfirmation(ctx, r, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { return err } diff --git a/internal/api/resend.go b/internal/api/resend.go index fdad38c43..3f48b4cd2 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -134,7 +134,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - mID, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) + mID, terr := a.sendPhoneConfirmation(ctx, r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) if terr != nil { return terr } @@ -146,7 +146,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - mID, terr := a.sendPhoneConfirmation(tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider) + mID, terr := a.sendPhoneConfirmation(ctx, r, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider) if terr != nil { return terr } diff --git a/internal/api/signup.go b/internal/api/signup.go index 5c7e588b8..c476cba05 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -278,7 +278,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Unable to get SMS provider").WithInternalError(terr) } - if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { + if _, terr := a.sendPhoneConfirmation(ctx, r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr) } } diff --git a/internal/api/token.go b/internal/api/token.go index df0292711..cee79be11 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -164,7 +164,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri Valid: isValidPassword, } output := hooks.PasswordVerificationAttemptOutput{} - err := a.invokeHook(ctx, nil, &input, &output) + err := a.invokePostgresHook(ctx, nil, &input, &output, config.Hook.PasswordVerificationAttempt.URI) if err != nil { return err } @@ -339,7 +339,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u output := hooks.CustomAccessTokenOutput{} - err := a.invokeHook(ctx, tx, &input, &output) + err := a.invokePostgresHook(ctx, tx, &input, &output, config.Hook.CustomAccessToken.URI) if err != nil { return "", 0, err } diff --git a/internal/api/user.go b/internal/api/user.go index 9fe0dcef8..7be379775 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -227,7 +227,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Error finding SMS provider").WithInternalError(terr) } - if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { + if _, terr := a.sendPhoneConfirmation(ctx, r, tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { return internalServerError("Error sending phone change otp").WithInternalError(terr) } } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 6786c4cfa..630c73798 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -503,12 +503,12 @@ func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { return validatePostgresPath(u) case "http": hostname := u.Hostname() - if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" { - return validateHTTPSHookSecrets(e.HTTPHookSecrets) + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == "host.docker.internal" { + return validateHTTPHookSecrets(e.HTTPHookSecrets) } return fmt.Errorf("only localhost, 127.0.0.1, and ::1 are supported with http") case "https": - return validateHTTPSHookSecrets(e.HTTPHookSecrets) + return validateHTTPHookSecrets(e.HTTPHookSecrets) default: return fmt.Errorf("only postgres hooks and HTTPS functions are supported at the moment") } @@ -536,7 +536,7 @@ func isValidSecretFormat(secret string) bool { return symmetricSecretFormat.MatchString(secret) || asymmetricSecretFormat.MatchString(secret) } -func validateHTTPSHookSecrets(secrets []string) error { +func validateHTTPHookSecrets(secrets []string) error { for _, secret := range secrets { if !isValidSecretFormat(secret) { return fmt.Errorf("invalid secret format") @@ -546,9 +546,6 @@ func validateHTTPSHookSecrets(secrets []string) error { } func (e *ExtensibilityPointConfiguration) PopulateExtensibilityPoint() error { - if err := e.ValidateExtensibilityPoint(); err != nil { - return err - } u, err := url.Parse(e.URI) if err != nil { return err diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index e8063ab1e..590d1ba4d 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -9,6 +9,11 @@ import ( "math" "math/big" "strconv" + "strings" + "time" + + "github.com/gofrs/uuid" + standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" "github.com/pkg/errors" ) @@ -41,3 +46,26 @@ func GenerateOtp(digits int) (string, error) { func GenerateTokenHash(emailOrPhone, otp string) string { return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp))) } + +func GenerateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) { + SymmetricSignaturePrefix := "v1," + // TODO(joel): Handle asymmetric case once library has been upgraded + var signatureList []string + for _, secret := range secrets { + if strings.HasPrefix(secret, SymmetricSignaturePrefix) { + trimmedSecret := strings.TrimPrefix(secret, SymmetricSignaturePrefix) + wh, err := standardwebhooks.NewWebhook(trimmedSecret) + if err != nil { + return nil, err + } + signature, err := wh.Sign(msgID.String(), currentTime, inputPayload) + if err != nil { + return nil, err + } + signatureList = append(signatureList, signature) + } else { + return nil, errors.New("invalid signature format") + } + } + return signatureList, nil +} diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index bd3163085..39a226ce7 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -22,6 +22,10 @@ const ( HookRejection = "reject" ) +type HTTPHookInput interface { + IsHTTPHook() +} + type HookOutput interface { IsError() bool Error() string @@ -135,6 +139,17 @@ type CustomAccessTokenOutput struct { HookError AuthHookError `json:"error,omitempty"` } +type CustomSMSProviderInput struct { + UserID uuid.UUID `json:"user_id"` + Phone string `json:"phone"` + OTP string `json:"otp"` +} + +type CustomSMSProviderOutput struct { + Success bool `json:"success"` + HookError AuthHookError `json:"error,omitempty"` +} + func (mf *MFAVerificationAttemptOutput) IsError() bool { return mf.HookError.Message != "" } @@ -159,6 +174,18 @@ func (ca *CustomAccessTokenOutput) Error() string { return ca.HookError.Message } +func (cs *CustomSMSProviderOutput) IsError() bool { + return cs.HookError.Message != "" +} + +func (cs *CustomSMSProviderOutput) Error() string { + return cs.HookError.Message +} + +func (cs *CustomSMSProviderOutput) IsHTTPHook() bool { + return true +} + type AuthHookError struct { HTTPCode int `json:"http_code,omitempty"` Message string `json:"message,omitempty"`