Skip to content

Commit

Permalink
feat: add hook log entry with run_hook action (#1684)
Browse files Browse the repository at this point in the history
Adds a log entry when hooks run. Also refactors the `invokeHook` API to
not require redundant parameters like the URI.
  • Loading branch information
hf authored Jul 29, 2024
1 parent 7de0cb3 commit 46491b8
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 23 deletions.
50 changes: 34 additions & 16 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"mime"
"net"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -188,21 +187,17 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon
// transaction is opened. If calling invokeHook within a transaction, always
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
// trigger.
func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, output any, uri string) error {
func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, output any) error {
var err error
var response []byte
u, err := url.Parse(uri)
if err != nil {
return err
}

switch input.(type) {
case *hooks.SendSMSInput:
hookOutput, ok := output.(*hooks.SendSMSOutput)
if !ok {
panic("output should be *hooks.SendSMSOutput")
}
if response, err = a.runHook(r, conn, a.config.Hook.SendSMS, input, output, u.Scheme); err != nil {
if response, err = a.runHook(r, conn, a.config.Hook.SendSMS, input, output); err != nil {
return err
}
if err := json.Unmarshal(response, hookOutput); err != nil {
Expand All @@ -226,7 +221,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu
if !ok {
panic("output should be *hooks.SendEmailOutput")
}
if response, err = a.runHook(r, conn, a.config.Hook.SendEmail, input, output, u.Scheme); err != nil {
if response, err = a.runHook(r, conn, a.config.Hook.SendEmail, input, output); err != nil {
return err
}
if err := json.Unmarshal(response, hookOutput); err != nil {
Expand All @@ -252,7 +247,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu
if !ok {
panic("output should be *hooks.MFAVerificationAttemptOutput")
}
if response, err = a.runHook(r, conn, a.config.Hook.MFAVerificationAttempt, input, output, u.Scheme); err != nil {
if response, err = a.runHook(r, conn, a.config.Hook.MFAVerificationAttempt, input, output); err != nil {
return err
}
if err := json.Unmarshal(response, hookOutput); err != nil {
Expand All @@ -279,7 +274,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu
panic("output should be *hooks.PasswordVerificationAttemptOutput")
}

if response, err = a.runHook(r, conn, a.config.Hook.PasswordVerificationAttempt, input, output, u.Scheme); err != nil {
if response, err = a.runHook(r, conn, a.config.Hook.PasswordVerificationAttempt, input, output); err != nil {
return err
}
if err := json.Unmarshal(response, hookOutput); err != nil {
Expand All @@ -306,7 +301,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu
if !ok {
panic("output should be *hooks.CustomAccessTokenOutput")
}
if response, err = a.runHook(r, conn, a.config.Hook.CustomAccessToken, input, output, u.Scheme); err != nil {
if response, err = a.runHook(r, conn, a.config.Hook.CustomAccessToken, input, output); err != nil {
return err
}
if err := json.Unmarshal(response, hookOutput); err != nil {
Expand Down Expand Up @@ -345,20 +340,43 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu
return nil
}

func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any, scheme string) ([]byte, error) {
func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
ctx := r.Context()

logEntry := observability.GetLogEntry(r)
hookStart := time.Now()

var response []byte
var err error
switch strings.ToLower(scheme) {
case "http", "https":

switch {
case strings.HasPrefix(hookConfig.URI, "http:") || strings.HasPrefix(hookConfig.URI, "https:"):
response, err = a.runHTTPHook(r, hookConfig, input)
case "pg-functions":
case strings.HasPrefix(hookConfig.URI, "pg-functions:"):
response, err = a.runPostgresHook(ctx, conn, hookConfig, input, output)
default:
return nil, fmt.Errorf("unsupported protocol: %v only postgres hooks and HTTPS functions are supported at the moment", scheme)
return nil, fmt.Errorf("unsupported protocol: %q only postgres hooks and HTTPS functions are supported at the moment", hookConfig.URI)
}

duration := time.Since(hookStart)

if err != nil {
logEntry.Entry.WithFields(logrus.Fields{
"action": "run_hook",
"hook": hookConfig.URI,
"success": false,
"duration": duration.Microseconds(),
}).WithError(err).Warn("Hook errored out")

return nil, internalServerError("Error running hook URI: %v", hookConfig.URI).WithInternalError(err)
}

logEntry.Entry.WithFields(logrus.Fields{
"action": "run_hook",
"hook": hookConfig.URI,
"success": true,
"duration": duration.Microseconds(),
}).WithError(err).Info("Hook ran successfully")

return response, nil
}
4 changes: 2 additions & 2 deletions internal/api/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func (ts *HooksTestSuite) TestInvokeHookIntegration() {
input: &hooks.SendEmailInput{},
output: &hooks.SendEmailOutput{},
uri: "ftp://example.com/path",
expectedError: errors.New("unsupported protocol: ftp only postgres hooks and HTTPS functions are supported at the moment"),
expectedError: errors.New("unsupported protocol: \"ftp://example.com/path\" only postgres hooks and HTTPS functions are supported at the moment"),
},
}

Expand All @@ -274,7 +274,7 @@ func (ts *HooksTestSuite) TestInvokeHookIntegration() {
require.NoError(ts.T(), ts.Config.Hook.SendEmail.PopulateExtensibilityPoint())

ts.Run(tc.description, func() {
err = ts.API.invokeHook(tc.conn, tc.request, tc.input, tc.output, tc.uri)
err = ts.API.invokeHook(tc.conn, tc.request, tc.input, tc.output)
if tc.expectedError != nil {
require.EqualError(ts.T(), err, tc.expectedError.Error())
} else {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User,
EmailData: emailData,
}
output := hooks.SendEmailOutput{}
return a.invokeHook(tx, r, &input, &output, a.config.Hook.SendEmail.URI)
return a.invokeHook(tx, r, &input, &output)
}

switch emailActionType {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
}

output := hooks.MFAVerificationAttemptOutput{}
err := a.invokeHook(nil, r, &input, &output, a.config.Hook.MFAVerificationAttempt.URI)
err := a.invokeHook(nil, r, &input, &output)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
},
}
output := hooks.SendSMSOutput{}
err := a.invokeHook(tx, r, &input, &output, a.config.Hook.SendSMS.URI)
err := a.invokeHook(tx, r, &input, &output)
if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
Valid: isValidPassword,
}
output := hooks.PasswordVerificationAttemptOutput{}
err := a.invokeHook(nil, r, &input, &output, a.config.Hook.PasswordVerificationAttempt.URI)
err := a.invokeHook(nil, r, &input, &output)
if err != nil {
return err
}
Expand Down Expand Up @@ -360,7 +360,7 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user

output := hooks.CustomAccessTokenOutput{}

err := a.invokeHook(tx, r, &input, &output, a.config.Hook.CustomAccessToken.URI)
err := a.invokeHook(tx, r, &input, &output)
if err != nil {
return "", 0, err
}
Expand Down

0 comments on commit 46491b8

Please sign in to comment.