From 5afd27290aec1729581b88342c9b5eba1444769b Mon Sep 17 00:00:00 2001 From: Maksim Terekhin Date: Tue, 16 Apr 2024 16:10:27 +0200 Subject: [PATCH] feat: Add go-kit Kafka transport --- pkg/transport/kafka/encode_decode.go | 25 ++ pkg/transport/kafka/publisher.go | 205 +++++++++ pkg/transport/kafka/publisher_test.go | 390 ++++++++++++++++++ pkg/transport/kafka/request_response_funcs.go | 78 ++++ pkg/transport/kafka/subscriber.go | 172 ++++++++ pkg/transport/kafka/subscriber_test.go | 113 +++++ 6 files changed, 983 insertions(+) create mode 100644 pkg/transport/kafka/encode_decode.go create mode 100644 pkg/transport/kafka/publisher.go create mode 100644 pkg/transport/kafka/publisher_test.go create mode 100644 pkg/transport/kafka/request_response_funcs.go create mode 100644 pkg/transport/kafka/subscriber.go create mode 100644 pkg/transport/kafka/subscriber_test.go diff --git a/pkg/transport/kafka/encode_decode.go b/pkg/transport/kafka/encode_decode.go new file mode 100644 index 0000000..86538e7 --- /dev/null +++ b/pkg/transport/kafka/encode_decode.go @@ -0,0 +1,25 @@ +package kafka + +import ( + "context" + + "github.com/twmb/franz-go/pkg/kgo" +) + +// DecodeRequestFunc extracts a user-domain request object from +// an Kafka message. It is designed to be used in Kafka Subscribers. +type DecodeRequestFunc func(ctx context.Context, msg *kgo.Record) (request interface{}, err error) + +// EncodeRequestFunc encodes the passed request object into +// an Kafka message object. It is designed to be used in Kafka Publishers. +type EncodeRequestFunc func(context.Context, *kgo.Record, interface{}) error + +// EncodeResponseFunc encodes the passed response object into +// a Kafka message object. It is designed to be used in Kafka Subscribers. +type EncodeResponseFunc func(context.Context, *kgo.Record, interface{}) error + +// DecodeResponseFunc extracts a user-domain response object from kafka +// response object. It's designed to be used in kafka publisher, for publisher-side +// endpoints. One straightforward DecodeResponseFunc could be something that +// JSON decodes from the response payload to the concrete response type. +type DecodeResponseFunc func(context.Context, *kgo.Record) (response interface{}, err error) diff --git a/pkg/transport/kafka/publisher.go b/pkg/transport/kafka/publisher.go new file mode 100644 index 0000000..7269ce4 --- /dev/null +++ b/pkg/transport/kafka/publisher.go @@ -0,0 +1,205 @@ +package kafka + +import ( + "context" + "encoding/json" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/twmb/franz-go/pkg/kgo" +) + +const ( + defaultPublisherTimeout = 10 * time.Second +) + +// Publisher wraps single Kafka topic for message publishing +// and implements endpoint.Endpoint. +type Publisher struct { + handler Handler + topic string + enc EncodeRequestFunc + dec DecodeResponseFunc + before []RequestFunc + after []PublisherResponseFunc + deliverer Deliverer + timeout time.Duration +} + +// NewPublisher constructs a new publisher for a single Kafka topic, +// which implements endpoint.Endpoint. +func NewPublisher( + handler Handler, + topic string, + enc EncodeRequestFunc, + dec DecodeResponseFunc, + options ...PublisherOption, +) *Publisher { + p := &Publisher{ + handler: handler, + topic: topic, + deliverer: SyncDeliverer, + enc: enc, + dec: dec, + timeout: defaultPublisherTimeout, + } + for _, opt := range options { + opt(p) + } + + return p +} + +// PublisherOption sets an optional parameter for publishers. +type PublisherOption func(publisher *Publisher) + +// PublisherBefore sets the RequestFuncs that are applied to the outgoing publisher +// request before it's invoked. +func PublisherBefore(before ...RequestFunc) PublisherOption { + return func(p *Publisher) { + p.before = append(p.before, before...) + } +} + +// PublisherAfter adds one or more PublisherResponseFuncs, which are applied to the +// context after successful message publishing. +// This is useful for context-manipulation operations. +func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { + return func(p *Publisher) { + p.after = append(p.after, after...) + } +} + +// PublisherDeliverer sets the deliverer function that the Publisher invokes. +func PublisherDeliverer(deliverer Deliverer) PublisherOption { + return func(p *Publisher) { p.deliverer = deliverer } +} + +// PublisherTimeout sets the available timeout for a kafka request. +func PublisherTimeout(timeout time.Duration) PublisherOption { + return func(p *Publisher) { p.timeout = timeout } +} + +// Endpoint returns a usable endpoint that invokes message publishing. +func (p Publisher) Endpoint() endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + ctx, cancel := context.WithTimeout(ctx, p.timeout) + defer cancel() + + msg := &kgo.Record{ + Topic: p.topic, + } + + if err := p.enc(ctx, msg, request); err != nil { + return nil, err + } + + for _, f := range p.before { + ctx = f(ctx, msg) + } + + event, err := p.deliverer(ctx, p, msg) + if err != nil { + return nil, err + } + + for _, f := range p.after { + ctx = f(ctx, event) + } + + response, err := p.dec(ctx, event) + if err != nil { + return nil, err + } + + return response, nil + } +} + +// Deliverer is invoked by the Publisher to publish the specified Message, and to +// retrieve the appropriate response Event object. +type Deliverer func( + context.Context, + Publisher, + *kgo.Record, +) (*kgo.Record, error) + +// SyncDeliverer is a deliverer that publishes the specified message +// and returns the first object. +// If the context times out while waiting for a reply, an error will be returned. +func SyncDeliverer(ctx context.Context, pub Publisher, msg *kgo.Record) (*kgo.Record, error) { + results := pub.handler.ProduceSync(ctx, msg) + + if len(results) > 0 && results[0].Err != nil { + return nil, results[0].Err + } + + return results[0].Record, nil +} + +// AsyncDeliverer delivers the supplied message and +// returns a nil response. +// +// When using this deliverer please ensure that the supplied DecodeResponseFunc and +// PublisherResponseFunc are able to handle nil-type responses. +// +// AsyncDeliverer will produce the message with the context detached due to the fact that actual +// message producing is called asynchronously (another goroutine) and at that time original context might be +// already canceled causing the producer to fail. The detached context will include values attached to the original +// context, but deadline and cancel will be reset. To provide a context for asynchronous deliverer please +// use AsyncDelivererCtx function instead. +func AsyncDeliverer(ctx context.Context, pub Publisher, msg *kgo.Record) (*kgo.Record, error) { + pub.handler.Produce(detach{ctx: ctx}, msg, nil) + + return nil, nil +} + +// AsyncDelivererCtx delivers the supplied message and +// returns a nil response. +// +// When using this deliverer please ensure that the supplied DecodeResponseFunc and +// PublisherResponseFunc are able to handle nil-type responses. +func AsyncDelivererCtx(ctx context.Context, pub Publisher, msg *kgo.Record) (*kgo.Record, error) { + pub.handler.Produce(ctx, msg, nil) + + return nil, nil +} + +// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a +// JSON object to the Message value. +// Many services can use it as a sensible default. +func EncodeJSONRequest(_ context.Context, msg *kgo.Record, request interface{}) error { + rawJSON, err := json.Marshal(request) + if err != nil { + return err + } + + msg.Value = rawJSON + + return nil +} + +// Handler is a handler interface to make testing possible. +// It is highly recommended to use *kafka.Producer as the interface implementation. +type Handler interface { + Produce(ctx context.Context, rec *kgo.Record, fn func(record *kgo.Record, err error)) + ProduceSync(ctx context.Context, rs ...*kgo.Record) kgo.ProduceResults +} + +type detach struct { + ctx context.Context +} + +func (d detach) Deadline() (time.Time, bool) { + return time.Time{}, false +} +func (d detach) Done() <-chan struct{} { + return nil +} +func (d detach) Err() error { + return nil +} + +func (d detach) Value(key interface{}) interface{} { + return d.ctx.Value(key) +} diff --git a/pkg/transport/kafka/publisher_test.go b/pkg/transport/kafka/publisher_test.go new file mode 100644 index 0000000..9a79016 --- /dev/null +++ b/pkg/transport/kafka/publisher_test.go @@ -0,0 +1,390 @@ +package kafka + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/twmb/franz-go/pkg/kgo" + + sdkloggercontext "github.com/scribd/go-sdk/pkg/context/logger" + sdkrequestidcontext "github.com/scribd/go-sdk/pkg/context/requestid" + sdklogger "github.com/scribd/go-sdk/pkg/logger" +) + +const ( + errString = "err" +) + +type ( + mockHandler struct { + producePromise func(ctx context.Context, record *kgo.Record, err error) + deliverAfter time.Duration + } + + testReq struct { + A int `json:"a"` + } + testRes struct { + A int `json:"a"` + } +) + +func (h *mockHandler) ProduceSync(ctx context.Context, rs ...*kgo.Record) kgo.ProduceResults { + var results kgo.ProduceResults + done := make(chan struct{}) + + h.Produce(ctx, rs[0], func(record *kgo.Record, err error) { + results = append(results, kgo.ProduceResult{Record: record, Err: err}) + close(done) + }) + + <-done + return results +} + +func (h *mockHandler) Produce(ctx context.Context, rec *kgo.Record, promise func(record *kgo.Record, err error)) { + fn := func(rec *kgo.Record, err error) { + if promise != nil { + promise(rec, err) + } + if h.producePromise != nil { + h.producePromise(ctx, rec, err) + } + } + + go func() { + select { + case <-ctx.Done(): + fn(rec, ctx.Err()) + case <-time.After(h.deliverAfter): + if fn != nil { + fn(rec, nil) + } + } + }() +} + +// TestBadEncode tests if encode errors are handled properly. +func TestBadEncode(t *testing.T) { + h := &mockHandler{} + pub := NewPublisher( + h, + "test", + func(context.Context, *kgo.Record, interface{}) error { return errors.New(errString) }, + func(context.Context, *kgo.Record) (response interface{}, err error) { return struct{}{}, nil }, + ) + errChan := make(chan error, 1) + var err error + go func() { + _, pubErr := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- pubErr + + }() + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + if err == nil { + t.Error("expected error") + } + if want, have := errString, err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestBadDecode tests if decode errors are handled properly. +func TestBadDecode(t *testing.T) { + h := &mockHandler{} + + pub := NewPublisher( + h, + "test", + func(context.Context, *kgo.Record, interface{}) error { return nil }, + func(context.Context, *kgo.Record) (response interface{}, err error) { + return struct{}{}, errors.New(errString) + }, + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, pubErr := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- pubErr + + }() + + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + } + if want, have := errString, err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestPublisherTimeout ensures that the publisher timeout mechanism works. +func TestPublisherTimeout(t *testing.T) { + h := &mockHandler{ + deliverAfter: time.Second, + } + + pub := NewPublisher( + h, + "test", + func(context.Context, *kgo.Record, interface{}) error { return nil }, + func(context.Context, *kgo.Record) (response interface{}, err error) { + return struct{}{}, nil + }, + PublisherTimeout(50*time.Millisecond), + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, pubErr := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- pubErr + }() + + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + } + if want, have := context.DeadlineExceeded.Error(), err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func TestSuccessfulPublisher(t *testing.T) { + mockReq := testReq{1} + mockRes := testRes{ + A: 1, + } + _, err := json.Marshal(mockRes) + if err != nil { + t.Fatal(err) + } + h := &mockHandler{} + + pub := NewPublisher( + h, + "test", + testReqEncoder, + testResMessageDecoder, + ) + var res testRes + var ok bool + resChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + go func() { + res, pubErr := pub.Endpoint()(context.Background(), mockReq) + if pubErr != nil { + errChan <- pubErr + } else { + resChan <- res + } + }() + + select { + case response := <-resChan: + res, ok = response.(testRes) + if !ok { + t.Error("failed to assert endpoint response type") + } + break + + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err != nil { + t.Fatal(err) + } + if want, have := mockRes.A, res.A; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +// TestSendAndForgetPublisher tests that the AsyncDeliverer is working +func TestAsyncPublisher(t *testing.T) { + finishChan := make(chan struct{}) + contextKey, contextValue := "test", "test" + h := &mockHandler{ + producePromise: func(ctx context.Context, record *kgo.Record, err error) { + if err != nil { + t.Fatal(err) + } + if want, have := contextValue, ctx.Value(contextKey); want != have { + t.Errorf("want %s, have %s", want, have) + } + finishChan <- struct{}{} + }, + } + + pub := NewPublisher( + h, + "test", + func(context.Context, *kgo.Record, interface{}) error { return nil }, + func(ctx context.Context, rec *kgo.Record) (response interface{}, err error) { + val := ctx.Value(contextValue).(string) + assert.Equal(t, contextValue, val) + + return struct{}{}, nil + }, + PublisherDeliverer(AsyncDeliverer), + PublisherTimeout(50*time.Millisecond), + ) + + var err error + errChan := make(chan error) + go func() { + ctx := context.WithValue(context.Background(), contextKey, contextValue) + ctx, cancel := context.WithCancel(ctx) + cancel() + _, pubErr := pub.Endpoint()(ctx, struct{}{}) + if pubErr != nil { + errChan <- pubErr + } + }() + + select { + case <-finishChan: + break + case err = <-errChan: + t.Errorf("unexpected error %s", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } +} + +func TestSetRequestID(t *testing.T) { + h := &mockHandler{} + + mockReq := testReq{1} + + pub := NewPublisher( + h, + "test", + func(context.Context, *kgo.Record, interface{}) error { return nil }, + func(context.Context, *kgo.Record) (response interface{}, err error) { + return struct{}{}, nil + }, + PublisherBefore(SetRequestID()), + PublisherDeliverer(func(ctx context.Context, publisher Publisher, message *kgo.Record) (*kgo.Record, error) { + r, err := sdkrequestidcontext.Extract(ctx) + require.NotNil(t, r) + require.Nil(t, err) + + return nil, nil + }), + ) + + _, err := pub.Endpoint()(context.Background(), mockReq) + require.Nil(t, err) +} + +func TestSetLogger(t *testing.T) { + h := &mockHandler{} + + mockReq := testReq{1} + + // Inject this "owned" buffer as Output in the logger wrapped by + // the loggingMiddleware under test. + var buffer bytes.Buffer + + config := &sdklogger.Config{ + ConsoleEnabled: true, + ConsoleJSONFormat: true, + ConsoleLevel: "info", + FileEnabled: false, + } + l, err := sdklogger.NewBuilder(config).BuildTestLogger(&buffer) + require.Nil(t, err) + + pub := NewPublisher( + h, + "test", + func(context.Context, *kgo.Record, interface{}) error { return nil }, + func(context.Context, *kgo.Record) (response interface{}, err error) { + return struct{}{}, nil + }, + PublisherBefore(SetLogger(l)), + PublisherDeliverer(func(ctx context.Context, publisher Publisher, message *kgo.Record) (*kgo.Record, error) { + l, ctxErr := sdkloggercontext.Extract(ctx) + require.NotNil(t, l) + require.Nil(t, ctxErr) + + l.Infof("test") + + return nil, nil + }), + PublisherAfter(func(ctx context.Context, ev *kgo.Record) context.Context { + l, ctxErr := sdkloggercontext.Extract(ctx) + require.NotNil(t, l) + require.Nil(t, ctxErr) + + return ctx + }), + ) + + _, err = pub.Endpoint()(context.Background(), mockReq) + require.Nil(t, err) + + var fields map[string]interface{} + err = json.Unmarshal(buffer.Bytes(), &fields) + require.Nil(t, err) + + assert.NotEmpty(t, fields["pubsub"]) + assert.NotEmpty(t, fields["dd"]) + + var pubsub = (fields["pubsub"]).(map[string]interface{}) + assert.NotNil(t, pubsub["request_id"]) +} + +func testReqEncoder(_ context.Context, m *kgo.Record, request interface{}) error { + req, ok := request.(testReq) + if !ok { + return errors.New("type assertion failure") + } + b, err := json.Marshal(req) + if err != nil { + return err + } + m.Value = b + return nil +} + +func testResMessageDecoder(_ context.Context, m *kgo.Record) (interface{}, error) { + return testResDecoder(m.Value) +} + +func testResDecoder(b []byte) (interface{}, error) { + var obj testRes + err := json.Unmarshal(b, &obj) + return obj, err +} diff --git a/pkg/transport/kafka/request_response_funcs.go b/pkg/transport/kafka/request_response_funcs.go new file mode 100644 index 0000000..9fdef25 --- /dev/null +++ b/pkg/transport/kafka/request_response_funcs.go @@ -0,0 +1,78 @@ +package kafka + +import ( + "context" + + "github.com/google/uuid" + "github.com/twmb/franz-go/pkg/kgo" + + sdkloggercontext "github.com/scribd/go-sdk/pkg/context/logger" + sdkmetricscontext "github.com/scribd/go-sdk/pkg/context/metrics" + sdkrequestidcontext "github.com/scribd/go-sdk/pkg/context/requestid" + sdkinstrumentation "github.com/scribd/go-sdk/pkg/instrumentation" + sdklogger "github.com/scribd/go-sdk/pkg/logger" + sdkmetrics "github.com/scribd/go-sdk/pkg/metrics" +) + +// RequestFunc may take information from a Kafka message and put it into a +// request context. In Subscribers, RequestFuncs are executed prior to invoking the +// endpoint. +type RequestFunc func(ctx context.Context, msg *kgo.Record) context.Context + +// SubscriberResponseFunc may take information from a request context and use it to +// manipulate a Publisher. SubscriberResponseFuncs are only executed in +// consumers, after invoking the endpoint but prior to publishing a reply. +type SubscriberResponseFunc func(ctx context.Context, response interface{}) context.Context + +// PublisherResponseFunc may take information from a request context. +// PublisherResponseFunc are only executed in producers, after a request has been produced. +type PublisherResponseFunc func(ctx context.Context, msg *kgo.Record) context.Context + +// SetMetrics returns RequestFunc that sets the Metrics client to the request context. +func SetMetrics(m sdkmetrics.Metrics) RequestFunc { + return func(ctx context.Context, msg *kgo.Record) context.Context { + return sdkmetricscontext.ToContext(ctx, m) + } +} + +// SetRequestID returns RequestFunc that sets RequestID to the request context if not previously set. +func SetRequestID() RequestFunc { + return func(ctx context.Context, msg *kgo.Record) context.Context { + _, err := sdkrequestidcontext.Extract(ctx) + if err != nil { + if uuidObject, err := uuid.NewRandom(); err == nil { + requestID := uuidObject.String() + return sdkrequestidcontext.ToContext(ctx, requestID) + } + } + + return ctx + } +} + +// SetLogger returns RequestFunc that sets SDK Logger to the request context. +// It will also try to setup context values to the logger fields. +func SetLogger(l sdklogger.Logger) RequestFunc { + return func(ctx context.Context, msg *kgo.Record) context.Context { + logContext := sdkinstrumentation.TraceLogs(ctx) + + requestID, err := sdkrequestidcontext.Extract(ctx) + if err != nil { + l.WithFields(sdklogger.Fields{ + "error": err.Error(), + }).Tracef("Could not retrieve request id from the context") + } + + logger := l.WithFields(sdklogger.Fields{ + "pubsub": sdklogger.Fields{ + "request_id": requestID, + }, + "dd": sdklogger.Fields{ + "trace_id": logContext.TraceID, + "span_id": logContext.SpanID, + }, + }) + + return sdkloggercontext.ToContext(ctx, logger) + } +} diff --git a/pkg/transport/kafka/subscriber.go b/pkg/transport/kafka/subscriber.go new file mode 100644 index 0000000..17dac54 --- /dev/null +++ b/pkg/transport/kafka/subscriber.go @@ -0,0 +1,172 @@ +package kafka + +import ( + "context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/transport" + "github.com/twmb/franz-go/pkg/kgo" + + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + + kafkasdk "github.com/scribd/go-sdk/pkg/instrumentation/kafka" + sdkkafka "github.com/scribd/go-sdk/pkg/pubsub/kafka" +) + +// Subscriber wraps an endpoint and provides a handler for kafka messages. +type Subscriber struct { + e endpoint.Endpoint + dec DecodeRequestFunc + before []RequestFunc + after []SubscriberResponseFunc + finalizer []SubscriberFinalizerFunc + errorHandler transport.ErrorHandler + errorEncoder ErrorEncoder +} + +// NewSubscriber constructs a new subscriber provides a handler for kafka messages. +func NewSubscriber( + e endpoint.Endpoint, + dec DecodeRequestFunc, + opts ...SubscriberOption, +) *Subscriber { + c := &Subscriber{ + e: e, + dec: dec, + errorEncoder: DefaultErrorEncoder, + errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), + } + for _, opt := range opts { + opt(c) + } + + return c +} + +// SubscriberOption sets an optional parameter for subscribers. +type SubscriberOption func(consumer *Subscriber) + +// SubscriberBefore functions are executed on the subscriber message object +// before the request is decoded. +func SubscriberBefore(before ...RequestFunc) SubscriberOption { + return func(c *Subscriber) { + c.before = append(c.before, before...) + } +} + +// SubscriberAfter functions are executed on the subscriber reply after the +// endpoint is invoked, but before anything is published to the reply. +func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption { + return func(c *Subscriber) { + c.after = append(c.after, after...) + } +} + +// SubscriberErrorEncoder is used to encode errors to the subscriber reply +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting. By default, +// errors will be published with the DefaultErrorEncoder. +func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption { + return func(s *Subscriber) { s.errorEncoder = ee } +} + +// SubscriberErrorHandler is used to handle non-terminal errors. By default, non-terminal errors +// are ignored. This is intended as a diagnostic measure. +func SubscriberErrorHandler(errorHandler transport.ErrorHandler) SubscriberOption { + return func(c *Subscriber) { + c.errorHandler = errorHandler + } +} + +// SubscriberFinalizer is executed at the end of every message processing. +// By default, no finalizer is registered. +func SubscriberFinalizer(f ...SubscriberFinalizerFunc) SubscriberOption { + return func(c *Subscriber) { + c.finalizer = append(c.finalizer, f...) + } +} + +// ServeMsg provides kafka.MsgHandler. +func (s Subscriber) ServeMsg(h Handler) sdkkafka.MsgHandler { + return func(msg *kgo.Record) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if len(s.finalizer) > 0 { + defer func() { + for _, f := range s.finalizer { + f(ctx, msg) + } + }() + } + + for _, f := range s.before { + ctx = f(ctx, msg) + } + + request, err := s.dec(ctx, msg) + if err != nil { + s.errorEncoder(ctx, err, msg, h) + s.errorHandler.Handle(ctx, err) + return + } + + response, err := s.e(ctx, request) + if err != nil { + s.errorEncoder(ctx, err, msg, h) + s.errorHandler.Handle(ctx, err) + return + } + + for _, f := range s.after { + ctx = f(ctx, response) + } + } +} + +// SubscriberFinalizerFunc can be used to perform work at the end of message processing, +// after the response has been constructed. The principal +// intended use is for request logging. +type SubscriberFinalizerFunc func(ctx context.Context, msg *kgo.Record) + +// ErrorEncoder is responsible for encoding an error to the subscriber reply. +// Users are encouraged to use custom ErrorEncoders to encode errors to +// their replies, and will likely want to pass and check for their own error +// types. +type ErrorEncoder func(ctx context.Context, + err error, msg *kgo.Record, h Handler) + +// DefaultErrorEncoder simply ignores the message. +func DefaultErrorEncoder(ctx context.Context, + err error, msg *kgo.Record, h Handler) { +} + +// NewInstrumentedSubscriber constructs a new subscriber provides a handler for kafka messages. +// It also instruments the subscriber with datadog tracing. +func NewInstrumentedSubscriber(e endpoint.Endpoint, dec DecodeRequestFunc, opts ...SubscriberOption) *Subscriber { + options := []SubscriberOption{ + SubscriberBefore(startMessageHandlerTrace), + SubscriberFinalizer(finishMessageHandlerTrace), + } + + options = append(options, opts...) + + return NewSubscriber(e, dec, options...) +} + +func startMessageHandlerTrace(ctx context.Context, msg *kgo.Record) context.Context { + if spanctx, err := tracer.Extract(kafkasdk.NewMessageCarrier(msg)); err == nil { + span := tracer.StartSpan("kafka.msghandler", tracer.ChildOf(spanctx)) + + ctx = tracer.ContextWithSpan(ctx, span) + } + + return ctx +} + +func finishMessageHandlerTrace(ctx context.Context, msg *kgo.Record) { + if span, ok := tracer.SpanFromContext(ctx); ok { + span.Finish() + } +} diff --git a/pkg/transport/kafka/subscriber_test.go b/pkg/transport/kafka/subscriber_test.go new file mode 100644 index 0000000..482f400 --- /dev/null +++ b/pkg/transport/kafka/subscriber_test.go @@ -0,0 +1,113 @@ +package kafka + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/twmb/franz-go/pkg/kgo" +) + +// TestSubscriberBadDecode checks if decoder errors are handled properly. +func TestSubscriberBadDecode(t *testing.T) { + errCh := make(chan error, 1) + + sub := NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *kgo.Record) (interface{}, error) { return nil, errors.New("err!") }, + SubscriberErrorEncoder(createTestErrorEncoder(errCh)), + ) + + sub.ServeMsg(nil)(&kgo.Record{}) + + var err error + select { + case err = <-errCh: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSubscriberBadEndpoint checks if endpoint errors are handled properly. +func TestSubscriberBadEndpoint(t *testing.T) { + errCh := make(chan error, 1) + + sub := NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("err!") }, + func(context.Context, *kgo.Record) (interface{}, error) { return struct{}{}, nil }, + SubscriberErrorEncoder(createTestErrorEncoder(errCh)), + ) + + sub.ServeMsg(nil)(&kgo.Record{}) + + var err error + select { + case err = <-errCh: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func TestSubscriberSuccess(t *testing.T) { + obj := testReq{A: 1} + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + + sub := NewSubscriber( + testEndpoint, + testReqDecoder, + SubscriberAfter(func(ctx context.Context, response interface{}) context.Context { + res := response.(testRes) + if res.A != 2 { + t.Errorf("got wrong result: %d", res.A) + } + + return ctx + }), + ) + + // golangci-lint v1.44.2 in linux reports that variable 'b' declared but unused + // in case of usage of self-invoked function. Probably need an investigation + handler := sub.ServeMsg(nil) + handler(&kgo.Record{ + Value: b, + }) +} + +func createTestErrorEncoder(ch chan error) ErrorEncoder { + return func(ctx context.Context, err error, msg *kgo.Record, h Handler) { + ch <- err + } +} + +func testReqDecoder(_ context.Context, m *kgo.Record) (interface{}, error) { + var obj testReq + err := json.Unmarshal(m.Value, &obj) + return obj, err +} + +func testEndpoint(_ context.Context, request interface{}) (interface{}, error) { + req, ok := request.(testReq) + if !ok { + return nil, errors.New("type assertion error") + } + + res := testRes{ + A: req.A + 1, + } + return res, nil +}