Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add custom sms hook #1474

Merged
merged 12 commits into from
Mar 27, 2024
Merged
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ require (
github.com/fatih/structs v1.1.0
github.com/gobuffalo/pop/v6 v6.1.1
github.com/jackc/pgx/v4 v4.18.2
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721
github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869
github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747
github.com/xeipuuv/gojsonschema v1.2.0
Expand Down Expand Up @@ -146,4 +147,6 @@ require (
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.21
go 1.21.0

toolchain go1.21.6
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUq
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc=
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
Expand Down
4 changes: 4 additions & 0 deletions internal/api/errorcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,8 @@ const (
ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit"
ErrorBadCodeVerifier ErrorCode = "bad_code_verifier"
ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled"
ErrorCodeHookTimeout ErrorCode = "hook_timeout"
ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry"
ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit"
ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size"
)
2 changes: 1 addition & 1 deletion internal/api/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestIsValidCodeChallenge(t *testing.T) {
}
}

func TestIsValidPKCEParmas(t *testing.T) {
func TestIsValidPKCEParams(t *testing.T) {
J0 marked this conversation as resolved.
Show resolved Hide resolved
cases := []struct {
challengeMethod string
challenge string
Expand Down
159 changes: 152 additions & 7 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
package api

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"

"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/observability"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"

"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/hooks"

"github.com/supabase/auth/internal/storage"
)

const (
DefaultHTTPHookTimeout = 5 * time.Second
DefaultHTTPHookRetries = 3
HTTPHookBackoffDuration = 2 * time.Second
PayloadLimit = 200 * 1024 // 200KB
)

func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
db := a.db.WithContext(ctx)

Expand Down Expand Up @@ -55,20 +75,145 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name
return response, nil
}

// invokeHook invokes the hook code. tx can be nil, in which case a new
func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
client := http.Client{
Timeout: DefaultHTTPHookTimeout,
}
ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout)
defer cancel()

log := observability.GetLogEntry(r)
requestURL := hookConfig.URI
hookLog := log.WithFields(logrus.Fields{
"component": "auth_hook",
"url": requestURL,
})

inputPayload, err := json.Marshal(input)
if err != nil {
return nil, err
}
for i := 0; i < DefaultHTTPHookRetries; i++ {
if i == 0 {
hookLog.Debugf("invocation attempt: %d", i)
} else {
hookLog.Infof("invocation attempt: %d", i)
}
msgID := uuid.Must(uuid.NewV4())
currentTime := time.Now()
signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload))
J0 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
panic("Failed to make request object")
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("webhook-id", msgID.String())
req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))
// By default, Go Client sets encoding to gzip, which does not carry a content length header.
req.Header.Set("Accept-Encoding", "identity")
J0 marked this conversation as resolved.
Show resolved Hide resolved

rsp, err := client.Do(req)
if err != nil && errors.Is(err, context.DeadlineExceeded) {
return nil, unprocessableEntityError(ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds()))

} else if err != nil {
if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 {
J0 marked this conversation as resolved.
Show resolved Hide resolved
hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
time.Sleep(HTTPHookBackoffDuration)
continue
} else if i == DefaultHTTPHookRetries-1 {
return nil, unprocessableEntityError(ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries")
} else {
return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err)
}
}
J0 marked this conversation as resolved.
Show resolved Hide resolved

defer rsp.Body.Close()
J0 marked this conversation as resolved.
Show resolved Hide resolved

switch rsp.StatusCode {
case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
if rsp.Body == nil {
return nil, nil
}
contentLength := rsp.ContentLength
J0 marked this conversation as resolved.
Show resolved Hide resolved
if contentLength == -1 {
return nil, unprocessableEntityError(ErrorCodeHookPayloadUnknownSize, "Payload size not known")
}
if contentLength >= PayloadLimit {
return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size is: %d bytes exceeded size limit of %d bytes", contentLength, PayloadLimit))
}
limitedReader := io.LimitedReader{R: rsp.Body, N: contentLength}
body, err := io.ReadAll(&limitedReader)
if err != nil {
return nil, err
}
return body, nil
case http.StatusTooManyRequests, http.StatusServiceUnavailable:
retryAfterHeader := rsp.Header.Get("retry-after")
// Check for truthy values to allow for flexibility to switch to time duration
if retryAfterHeader != "" {
continue
}
return nil, internalServerError("Service currently unavailable due to hook")
case http.StatusBadRequest:
return nil, internalServerError("Invalid payload sent to hook")
case http.StatusUnauthorized:
return nil, internalServerError("Hook requires authorization token")
default:
return nil, internalServerError("Error executing Hook")
}
}
return nil, nil
}

func (a *API) invokeHTTPHook(ctx context.Context, r *http.Request, input, output any, hookURI string) error {
switch input.(type) {
case *hooks.CustomSMSProviderInput:
hookOutput, ok := output.(*hooks.CustomSMSProviderOutput)
if !ok {
panic("output should be *hooks.CustomSMSProviderOutput")
}
var response []byte
var err error

if response, err = a.runHTTPHook(ctx, r, a.config.Hook.CustomSMSProvider, input, output); err != nil {
return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err)
}
if err != nil {
return err
}

if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err)
}

default:
panic("unknown HTTP hook type")
}
return nil
}

// invokePostgresHook invokes the hook code. tx can be nil, in which case a new
// transaction is opened. If calling invokeHook within a transaction, always
// pass the current transaciton, as pool-exhaustion deadlocks are very easy to
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
// trigger.
func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, output any) error {
func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error {
config := a.config
// Switch based on hook type
switch input.(type) {
case *hooks.MFAVerificationAttemptInput:
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.MFAVerificationAttemptOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
}

Expand All @@ -94,7 +239,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
panic("output should be *hooks.PasswordVerificationAttemptOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking password verification hook.").WithInternalError(err)
}

Expand All @@ -120,7 +265,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
panic("output should be *hooks.CustomAccessTokenOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
return internalServerError("Error invoking access token hook.").WithInternalError(err)
}

Expand Down Expand Up @@ -155,6 +300,6 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
return nil

default:
panic("unknown hook input type")
panic("unknown Postgres hook input type")
}
}
Loading
Loading