diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index b798f9e39914..1327ea747de6 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -12,7 +12,9 @@ import ( "github.com/pkg/errors" "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.11.0" "go.opentelemetry.io/otel/trace" "github.com/ory/kratos/ui/node" @@ -30,7 +32,6 @@ import ( "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/x" - "github.com/ory/x/otelx" ) var ( @@ -254,22 +255,6 @@ func (e *WebHook) ExecuteSettingsPrePersistHook(_ http.ResponseWriter, req *http } func (e *WebHook) execute(ctx context.Context, data *templateContext) error { - span := trace.SpanFromContext(ctx) - attrs := map[string]string{ - "webhook.http.method": data.RequestMethod, - "webhook.http.url": data.RequestURL, - "webhook.http.headers": fmt.Sprintf("%#v", data.RequestHeaders), - } - - if data.Identity != nil { - attrs["webhook.identity.id"] = data.Identity.ID.String() - } else { - attrs["webhook.identity.id"] = "" - } - - span.SetAttributes(otelx.StringAttrs(attrs)...) - defer span.End() - builder, err := request.NewBuilder(e.conf, e.deps) if err != nil { return err @@ -282,34 +267,60 @@ func (e *WebHook) execute(ctx context.Context, data *templateContext) error { return err } - errChan := make(chan error, 1) + attrs := semconv.HTTPClientAttributesFromHTTPRequest(req.Request) + if data.Identity != nil { + attrs = append(attrs, + attribute.String("webhook.identity.id", data.Identity.ID.String()), + attribute.String("webhook.identity.nid", data.Identity.NID.String()), + ) + } + var ( + httpClient = e.deps.HTTPClient(ctx) + async = gjson.GetBytes(e.conf, "response.ignore").Bool() + parseResponse = gjson.GetBytes(e.conf, "can_interrupt").Bool() + tracer = trace.SpanFromContext(ctx).TracerProvider().Tracer("kratos-webhooks") + cancel context.CancelFunc = func() {} + spanOpts = []trace.SpanStartOption{trace.WithAttributes(attrs...)} + errChan = make(chan error, 1) + ) + if async { + // dissociate the context from the one passed into this function + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Minute) + spanOpts = append(spanOpts, trace.WithNewRoot()) + } + ctx, span := tracer.Start(ctx, "Webhook", spanOpts...) e.deps.Logger().WithRequest(req.Request).Info("Dispatching webhook") t0 := time.Now() go func() { defer close(errChan) + defer cancel() + defer span.End() - resp, err := e.deps.HTTPClient(ctx).Do(req.WithContext(ctx)) + resp, err := httpClient.Do(req.WithContext(ctx)) if err != nil { + span.SetStatus(codes.Error, err.Error()) errChan <- errors.WithStack(err) return } defer resp.Body.Close() + span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(resp.StatusCode)...) if resp.StatusCode >= http.StatusBadRequest { - if gjson.GetBytes(e.conf, "can_interrupt").Bool() { + span.SetStatus(codes.Error, "HTTP status code >= 400") + if parseResponse { if err := parseWebhookResponse(resp); err != nil { + span.SetStatus(codes.Error, err.Error()) errChan <- err } } - errChan <- fmt.Errorf("web hook failed with status code %v", resp.StatusCode) - span.SetStatus(codes.Error, fmt.Sprintf("web hook failed with status code %v", resp.StatusCode)) + errChan <- fmt.Errorf("webhook failed with status code %v", resp.StatusCode) return } errChan <- nil }() - if gjson.GetBytes(e.conf, "response.ignore").Bool() { + if async { traceID, spanID := span.SpanContext().TraceID(), span.SpanContext().SpanID() go func() { if err := <-errChan; err != nil { diff --git a/selfservice/hook/web_hook_integration_test.go b/selfservice/hook/web_hook_integration_test.go index 9f689f76ede0..05bb4fa2ff55 100644 --- a/selfservice/hook/web_hook_integration_test.go +++ b/selfservice/hook/web_hook_integration_test.go @@ -842,10 +842,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { } func TestAsyncWebhook(t *testing.T) { - conf, reg := internal.NewFastRegistryWithMocks(t) - _ = conf - // conf.MustSet(ctx, config.ViperKeyClientHTTPNoPrivateIPRanges, true) - // conf.MustSet(ctx, config.ViperKeyClientHTTPPrivateIPExceptionURLs, []string{webhookReceiver.URL}) + _, reg := internal.NewFastRegistryWithMocks(t) logger := logrusx.New("kratos", "test") logHook := new(test.Hook) logger.Logger.Hooks.Add(logHook) @@ -866,6 +863,7 @@ func TestAsyncWebhook(t *testing.T) { } incomingCtx, incomingCancel := context.WithCancel(context.Background()) if deadline, ok := t.Deadline(); ok { + // cancel this context one second before test timeout for clean shutdown var cleanup context.CancelFunc incomingCtx, cleanup = context.WithDeadline(incomingCtx, deadline.Add(-time.Second)) defer cleanup() @@ -881,7 +879,6 @@ func TestAsyncWebhook(t *testing.T) { w.Write([]byte("ok")) })) t.Cleanup(webhookReceiver.Close) - // defer webhookReceiver.Close() wh := hook.NewWebHook(&whDeps, json.RawMessage(fmt.Sprintf(` { @@ -902,7 +899,7 @@ func TestAsyncWebhook(t *testing.T) { } // at this point, a goroutine is in the middle of the call to our test handler and waiting for a response incomingCancel() // simulate the incoming Kratos request having finished - testFor := time.After(200 * time.Millisecond) + timeout := time.After(200 * time.Millisecond) for done := false; !done; { if last := logHook.LastEntry(); last != nil { msg, err := last.String() @@ -911,7 +908,7 @@ func TestAsyncWebhook(t *testing.T) { } select { - case <-testFor: + case <-timeout: done = true case <-time.After(50 * time.Millisecond): // continue loop @@ -919,16 +916,17 @@ func TestAsyncWebhook(t *testing.T) { } logHook.Reset() close(blockHandlerOnExit) - testFor = time.After(200 * time.Millisecond) - for done := false; !done; { + timeout = time.After(200 * time.Millisecond) + for { if last := logHook.LastEntry(); last != nil { msg, err := last.String() require.NoError(t, err) assert.Contains(t, msg, "Webhook request succeeded") + break } select { - case <-testFor: - done = true + case <-timeout: + t.Fatal("timed out waiting for successful webhook completion") case <-time.After(50 * time.Millisecond): // continue loop }