diff --git a/go.mod b/go.mod
index 06feacd..57c664e 100644
--- a/go.mod
+++ b/go.mod
@@ -8,6 +8,7 @@ require (
 	github.com/DataDog/datadog-go v4.8.2+incompatible
 	github.com/aws/aws-sdk-go v1.34.28
 	github.com/getsentry/sentry-go v0.12.0
+	github.com/go-kit/kit v0.9.0
 	github.com/google/uuid v1.6.0
 	github.com/grpc-ecosystem/go-grpc-middleware v1.0.0
 	github.com/magefile/mage v1.15.0
@@ -16,6 +17,7 @@ require (
 	github.com/spf13/viper v1.10.1
 	github.com/stretchr/testify v1.8.4
 	github.com/twmb/franz-go v1.12.1
+	github.com/twmb/franz-go/pkg/kmsg v1.4.0
 	google.golang.org/grpc v1.60.1
 	google.golang.org/protobuf v1.32.0
 	gopkg.in/DataDog/dd-trace-go.v1 v1.47.0
@@ -38,6 +40,7 @@ require (
 	github.com/dgraph-io/ristretto v0.1.0 // indirect
 	github.com/dustin/go-humanize v1.0.0 // indirect
 	github.com/fsnotify/fsnotify v1.5.1 // indirect
+	github.com/go-logfmt/logfmt v0.6.0 // indirect
 	github.com/go-sql-driver/mysql v1.7.0 // indirect
 	github.com/golang/glog v1.1.2 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
@@ -67,7 +70,6 @@ require (
 	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/subosito/gotenv v1.2.0 // indirect
 	github.com/tinylib/msgp v1.1.6 // indirect
-	github.com/twmb/franz-go/pkg/kmsg v1.4.0 // indirect
 	go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect
 	go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 // indirect
 	golang.org/x/net v0.20.0 // indirect
diff --git a/go.sum b/go.sum
index 4bbf6c9..ff12e14 100644
--- a/go.sum
+++ b/go.sum
@@ -70,10 +70,16 @@ github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM=
 github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
 github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
 github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
+github.com/go-kit/kit v0.9.0 h1:wDJmvq38kDhkVxi50ni9ykkdUr1PKgqKOoi01fa0Mdk=
+github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
+github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
+github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
 github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8=
 github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
 github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
 github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
+github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
+github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
 github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
 github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
 github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
diff --git a/pkg/pubsub/kafka/config.go b/pkg/pubsub/kafka/config.go
new file mode 100644
index 0000000..d043d45
--- /dev/null
+++ b/pkg/pubsub/kafka/config.go
@@ -0,0 +1,154 @@
+package kafka
+
+import (
+	"context"
+	"crypto/tls"
+	"crypto/x509"
+	"net"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/sts"
+	"github.com/twmb/franz-go/pkg/kgo"
+	awssasl "github.com/twmb/franz-go/pkg/sasl/aws"
"github.com/twmb/franz-go/pkg/sasl/plain" + + sdklogger "github.com/scribd/go-sdk/pkg/logger" + "github.com/scribd/go-sdk/pkg/pubsub" +) + +// Config provides a common configuration for Kafka PubSub clients. +type Config struct { + // Application name that will be used in a serviceName provided to tracer spans + ApplicationName string + // Kafka configuration provided by go-sdk + KafkaConfig pubsub.Kafka + // AWS session reference, it will be used in case AWS MSK IAM authentication mechanism is used + AwsSession *session.Session + // MsgHandler is a function that will be called when a message is received + MsgHandler MsgHandler + Logger sdklogger.Logger +} + +const tlsConnectionTimeout = 10 * time.Second + +func newConfig(c Config, opts ...kgo.Opt) ([]kgo.Opt, error) { + options := []kgo.Opt{ + kgo.SeedBrokers(c.KafkaConfig.BrokerUrls...), + kgo.ClientID(c.KafkaConfig.ClientId), + } + + if c.KafkaConfig.SASL.Enabled { + switch c.KafkaConfig.SASLMechanism() { + case pubsub.Plain: + options = append(options, getPlainSaslOption(c.KafkaConfig.SASL)) + case pubsub.AWSMskIam: + options = append(options, getAwsMskIamSaslOption(c.KafkaConfig.SASL.AWSMskIam, c.AwsSession)) + } + } + + if c.KafkaConfig.TLS.Enabled || c.KafkaConfig.SecurityProtocol == "ssl" { + var caCertPool *x509.CertPool + + if c.KafkaConfig.TLS.Ca != "" { + caCertPool = x509.NewCertPool() + caCertPool.AppendCertsFromPEM([]byte(c.KafkaConfig.TLS.Ca)) + } + + var certificates []tls.Certificate + if c.KafkaConfig.TLS.Cert != "" && c.KafkaConfig.TLS.CertKey != "" { + cert, err := tls.X509KeyPair([]byte(c.KafkaConfig.TLS.Cert), []byte(c.KafkaConfig.TLS.CertKey)) + if err != nil { + return nil, err + } + certificates = []tls.Certificate{cert} + } + + if c.KafkaConfig.Cert != "" && c.KafkaConfig.CertKey != "" { + cert, err := tls.X509KeyPair([]byte(c.KafkaConfig.Cert), []byte(c.KafkaConfig.CertKey)) + if err != nil { + return nil, err + } + certificates = []tls.Certificate{cert} + } + + var skipTLSVerify bool + if c.KafkaConfig.TLS.InsecureSkipVerify || !c.KafkaConfig.SSLVerificationEnabled { + skipTLSVerify = true + } + + tlsDialer := &tls.Dialer{ + NetDialer: &net.Dialer{Timeout: tlsConnectionTimeout}, + Config: &tls.Config{ + InsecureSkipVerify: skipTLSVerify, + Certificates: certificates, + RootCAs: caCertPool, + }, + } + + options = append(options, kgo.Dialer(tlsDialer.DialContext)) + } + + options = append(options, opts...) 
+
+	return options, nil
+}
+
+func getPlainSaslOption(saslConf pubsub.SASL) kgo.Opt {
+	return kgo.SASL(plain.Auth{
+		User: saslConf.Username,
+		Pass: saslConf.Password,
+	}.AsMechanism())
+}
+
+func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session) kgo.Opt {
+	var opt kgo.Opt
+
+	// no AWS session provided
+	if s == nil {
+		opt = kgo.SASL(awssasl.Auth{
+			AccessKey:    iamConf.AccessKey,
+			SecretKey:    iamConf.SecretKey,
+			SessionToken: iamConf.SessionToken,
+			UserAgent:    iamConf.UserAgent,
+		}.AsManagedStreamingIAMMechanism())
+	} else {
+		opt = kgo.SASL(
+			awssasl.ManagedStreamingIAM(func(ctx context.Context) (awssasl.Auth, error) {
+				// If assumable role is not provided, we try to get credentials from the provided AWS session
+				if iamConf.AssumableRole == "" {
+					val, err := s.Config.Credentials.Get()
+					if err != nil {
+						return awssasl.Auth{}, err
+					}
+
+					return awssasl.Auth{
+						AccessKey:    val.AccessKeyID,
+						SecretKey:    val.SecretAccessKey,
+						SessionToken: val.SessionToken,
+						UserAgent:    iamConf.UserAgent,
+					}, nil
+				}
+
+				svc := sts.New(s)
+
+				res, stsErr := svc.AssumeRole(&sts.AssumeRoleInput{
+					RoleArn:         &iamConf.AssumableRole,
+					RoleSessionName: &iamConf.SessionName,
+				})
+				if stsErr != nil {
+					return awssasl.Auth{}, stsErr
+				}
+
+				return awssasl.Auth{
+					AccessKey:    *res.Credentials.AccessKeyId,
+					SecretKey:    *res.Credentials.SecretAccessKey,
+					SessionToken: *res.Credentials.SessionToken,
+					UserAgent:    iamConf.UserAgent,
+				}, nil
+			}),
+		)
+	}
+
+	return opt
+}
diff --git a/pkg/pubsub/kafka/partitionconsumer.go b/pkg/pubsub/kafka/partitionconsumer.go
new file mode 100644
index 0000000..824f66c
--- /dev/null
+++ b/pkg/pubsub/kafka/partitionconsumer.go
@@ -0,0 +1,42 @@
+package kafka
+
+import (
+	"context"
+
+	"github.com/twmb/franz-go/pkg/kgo"
+
+	sdkkafka "github.com/scribd/go-sdk/pkg/instrumentation/kafka"
+	sdklogger "github.com/scribd/go-sdk/pkg/logger"
+)
+
+type pconsumer struct {
+	pool *pool
+
+	quit chan struct{}
+	done chan struct{}
+	recs chan *sdkkafka.FetchPartition
+}
+
+func (pc *pconsumer) consume(cl *kgo.Client, logger sdklogger.Logger, shouldCommit bool, handler func(*kgo.Record)) {
+	defer close(pc.done)
+
+	for {
+		select {
+		case <-pc.quit:
+			return
+		case p := <-pc.recs:
+			p.EachRecord(func(rec *kgo.Record) {
+				pc.pool.Schedule(func() {
+					defer p.ConsumeRecord(rec)
+
+					handler(rec)
+				})
+			})
+			if shouldCommit {
+				if err := cl.CommitRecords(context.Background(), p.Records...); err != nil {
+					logger.WithError(err).Errorf("Partition consumer failed to commit records")
+				}
+			}
+		}
+	}
+}
diff --git a/pkg/pubsub/kafka/pool.go b/pkg/pubsub/kafka/pool.go
new file mode 100644
index 0000000..e80ce1b
--- /dev/null
+++ b/pkg/pubsub/kafka/pool.go
@@ -0,0 +1,33 @@
+package kafka
+
+type pool struct {
+	sem  chan struct{}
+	work chan func()
+}
+
+func newPool(size int) *pool {
+	p := &pool{
+		sem:  make(chan struct{}, size),
+		work: make(chan func()),
+	}
+
+	return p
+}
+
+func (p *pool) Schedule(task func()) {
+	select {
+	case p.work <- task:
+		return
+	case p.sem <- struct{}{}:
+		go p.worker(task)
+	}
+}
+
+func (p *pool) worker(task func()) {
+	defer func() { <-p.sem }()
+
+	for {
+		task()
+		task = <-p.work
+	}
+}
diff --git a/pkg/pubsub/kafka/publisher.go b/pkg/pubsub/kafka/publisher.go
new file mode 100644
index 0000000..69377c1
--- /dev/null
+++ b/pkg/pubsub/kafka/publisher.go
@@ -0,0 +1,85 @@
+package kafka
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/twmb/franz-go/pkg/kgo"
+
+	sdkkafka "github.com/scribd/go-sdk/pkg/instrumentation/kafka"
+)
+
+type (
+	Publisher struct {
+		producer *sdkkafka.Client
+	}
+)
+
+const (
+	defaultFlushTimeout = time.Second * 10
+
+	publisherServiceNameSuffix = "pubsub-publisher"
+)
+
+// NewPublisher is a tiny wrapper around the go-sdk kafka.Client and provides an API to publish Kafka messages.
+func NewPublisher(c Config, opts ...kgo.Opt) (*Publisher, error) {
+	serviceName := fmt.Sprintf("%s-%s", c.ApplicationName, publisherServiceNameSuffix)
+
+	cfg, err := newConfig(c, opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	cfg = append(cfg, []kgo.Opt{
+		kgo.ProduceRequestTimeout(c.KafkaConfig.Publisher.WriteTimeout),
+		kgo.RecordRetries(c.KafkaConfig.Publisher.MaxAttempts),
+	}...)
+
+	producer, err := sdkkafka.NewClient(cfg, sdkkafka.WithServiceName(serviceName))
+	if err != nil {
+		return nil, err
+	}
+
+	return &Publisher{producer: producer}, nil
+}
+
+// Publish publishes a kgo.Record message.
+func (p *Publisher) Publish(ctx context.Context, rec *kgo.Record, fn func(record *kgo.Record, err error)) {
+	p.producer.Produce(ctx, rec, fn)
+}
+
+// Produce is an alias for Publish to satisfy the kafka go-kit transport.
+func (p *Publisher) Produce(ctx context.Context, rec *kgo.Record, fn func(record *kgo.Record, err error)) {
+	p.Publish(ctx, rec, fn)
+}
+
+// ProduceSync publishes kgo.Record messages synchronously.
+func (p *Publisher) ProduceSync(ctx context.Context, rs ...*kgo.Record) kgo.ProduceResults {
+	return p.producer.ProduceSync(ctx, rs...)
+}
+
+// GetKafkaProducer returns the underlying kafka client for fine-grained tuning purposes.
+func (p *Publisher) GetKafkaProducer() *sdkkafka.Client {
+	return p.producer
+}
+
+// Stop flushes and waits for outstanding messages and requests to complete delivery.
+// It also closes the producer instance.
+func (p *Publisher) Stop(ctx context.Context) error {
+	if _, deadlineSet := ctx.Deadline(); !deadlineSet {
+		timeoutCtx, cancel := context.WithTimeout(ctx, defaultFlushTimeout)
+		defer cancel()
+
+		ctx = timeoutCtx
+	}
+
+	err := p.producer.Flush(ctx)
+	if err != nil {
+		return err
+	}
+
+	p.producer.Close()
+
+	return nil
+}
diff --git a/pkg/pubsub/kafka/subscriber.go b/pkg/pubsub/kafka/subscriber.go
new file mode 100644
index 0000000..d256f8e
--- /dev/null
+++ b/pkg/pubsub/kafka/subscriber.go
@@ -0,0 +1,214 @@
+package kafka
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sync"
+
+	"github.com/twmb/franz-go/pkg/kerr"
+	"github.com/twmb/franz-go/pkg/kgo"
+	"github.com/twmb/franz-go/pkg/kmsg"
+
+	sdkkafka "github.com/scribd/go-sdk/pkg/instrumentation/kafka"
+	sdklogger "github.com/scribd/go-sdk/pkg/logger"
+)
+
+type (
+	Subscriber struct {
+		logger             sdklogger.Logger
+		consumer           *sdkkafka.Client
+		autoCommitDisabled bool
+		mu                 sync.Mutex
+		consumers          map[string]map[int32]pconsumer
+		handler            func(rec *kgo.Record)
+		numWorkers         int
+		maxRecords         int
+	}
+	MsgHandler func(msg *kgo.Record)
+)
+
+const (
+	subscriberServiceNameSuffix = "pubsub-subscriber"
+
+	defaultMaxRecords = 10000
+)
+
+// NewSubscriber is a tiny wrapper around the sdk kafka.Client and provides an API to subscribe to a Kafka topic.
+func NewSubscriber(c Config, opts ...kgo.Opt) (*Subscriber, error) {
+	serviceName := fmt.Sprintf("%s-%s", c.ApplicationName, subscriberServiceNameSuffix)
+
+	cfg, err := newConfig(c, opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	autoCommitDisabled := !c.KafkaConfig.Subscriber.AutoCommit.Enabled
+	if autoCommitDisabled {
+		cfg = append(cfg, kgo.DisableAutoCommit())
+	}
+	if c.KafkaConfig.Subscriber.BlockRebalance {
+		cfg = append(cfg, kgo.BlockRebalanceOnPoll())
+	}
+
+	s := &Subscriber{
+		logger:             c.Logger,
+		mu:                 sync.Mutex{},
+		consumers:          make(map[string]map[int32]pconsumer),
+		numWorkers:         c.KafkaConfig.Subscriber.Workers,
+		handler:            c.MsgHandler,
+		autoCommitDisabled: autoCommitDisabled,
+		maxRecords:         c.KafkaConfig.Subscriber.MaxRecords,
+	}
+
+	if s.maxRecords == 0 {
+		s.maxRecords = defaultMaxRecords
+	}
+
+	cfg = append(cfg, []kgo.Opt{
+		kgo.ConsumerGroup(c.KafkaConfig.Subscriber.GroupId),
+		kgo.ConsumeTopics(c.KafkaConfig.Subscriber.Topic),
+		kgo.OnPartitionsLost(s.lost),
+		kgo.OnPartitionsRevoked(s.revoked),
+		kgo.OnPartitionsAssigned(s.assigned),
+	}...)
+
+	client, err := sdkkafka.NewClient(cfg, sdkkafka.WithServiceName(serviceName))
+	if err != nil {
+		return nil, err
+	}
+	s.consumer = client
+
+	return s, nil
+}
+
+// Subscribe subscribes to a configured topic and reads messages.
+// It returns an unbuffered channel to inspect possible errors.
+func (s *Subscriber) Subscribe(ctx context.Context) chan error {
+	ch := make(chan error)
+
+	go func() {
+		defer close(ch)
+
+		for {
+			fetches := s.consumer.PollRecords(ctx, s.maxRecords)
+			if fetches.IsClientClosed() {
+				return
+			}
+
+			if fetches.Err() != nil {
+				var containsFatalErr bool
+
+				fetches.EachError(func(_ string, _ int32, err error) {
+					if !containsFatalErr {
+						containsFatalErr = isFatalFetchError(err)
+					}
+					ch <- err
+				})
+
+				if containsFatalErr {
+					return
+				}
+			}
+
+			fetches.EachTopic(func(t kgo.FetchTopic) {
+				s.mu.Lock()
+				tconsumers := s.consumers[t.Topic]
+				s.mu.Unlock()
+
+				if tconsumers == nil {
+					return
+				}
+				t.EachPartition(func(p kgo.FetchPartition) {
+					pc, ok := tconsumers[p.Partition]
+					if !ok {
+						return
+					}
+					select {
+					case pc.recs <- s.consumer.WrapFetchPartition(ctx, p):
+					case <-pc.quit:
+					}
+				})
+			})
+			if kgoClient, ok := s.consumer.KafkaClient.(*kgo.Client); ok {
+				// this call is a no-op when rebalance blocking is not enabled
+				kgoClient.AllowRebalance()
+			}
+		}
+	}()
+
+	return ch
+}
+
+func (s *Subscriber) assigned(_ context.Context, cl *kgo.Client, assigned map[string][]int32) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for topic, partitions := range assigned {
+		if s.consumers[topic] == nil {
+			s.consumers[topic] = make(map[int32]pconsumer)
+		}
+		for _, partition := range partitions {
+			pc := pconsumer{
+				quit: make(chan struct{}),
+				recs: make(chan *sdkkafka.FetchPartition),
+				pool: newPool(s.numWorkers),
+				done: make(chan struct{}),
+			}
+			s.consumers[topic][partition] = pc
+			go pc.consume(cl, s.logger, s.autoCommitDisabled, s.handler)
+		}
+	}
+}
+
+func (s *Subscriber) stopConsumers(lost map[string][]int32) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	var wg sync.WaitGroup
+	defer wg.Wait()
+
+	for topic, partitions := range lost {
+		ptopics := s.consumers[topic]
+		for _, partition := range partitions {
+			pc := ptopics[partition]
+			delete(ptopics, partition)
+
+			if len(ptopics) == 0 {
+				delete(s.consumers, topic)
+			}
+			close(pc.quit)
+			wg.Add(1)
+			go func() { <-pc.done; wg.Done() }()
+		}
+	}
+}
+
+func (s *Subscriber) lost(_ context.Context, _ *kgo.Client, lost map[string][]int32) {
+	s.stopConsumers(lost)
+}
+
+func (s *Subscriber) revoked(ctx context.Context, cl *kgo.Client, lost map[string][]int32) {
+	s.stopConsumers(lost)
+	if !s.autoCommitDisabled {
+		cl.CommitOffsetsSync(ctx, cl.MarkedOffsets(),
+			func(cl *kgo.Client, _ *kmsg.OffsetCommitRequest, _ *kmsg.OffsetCommitResponse, err error) {
+				if err != nil {
+					s.logger.WithError(err).Errorf("Revoke commit failed")
+				}
+			},
+		)
+	}
+}
+
+func (s *Subscriber) Unsubscribe() error {
+	s.consumer.Close()
+
+	return nil
+}
+
+func isFatalFetchError(err error) bool {
+	var kafkaErr *kerr.Error
+
+	return errors.As(err, &kafkaErr)
+}
diff --git a/pkg/pubsub/kafka/subscriber_test.go b/pkg/pubsub/kafka/subscriber_test.go
new file mode 100644
index 0000000..5e72c6b
--- /dev/null
+++ b/pkg/pubsub/kafka/subscriber_test.go
@@ -0,0 +1,404 @@
+package kafka
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/twmb/franz-go/pkg/kerr"
+	"github.com/twmb/franz-go/pkg/kgo"
+
+	"github.com/scribd/go-sdk/pkg/instrumentation/kafka"
+	sdklogger "github.com/scribd/go-sdk/pkg/logger"
+)
+
+type mockKafkaClient struct {
+	fetches chan kgo.Fetches
+}
+
+func (c *mockKafkaClient) Produce(ctx context.Context, r *kgo.Record, promise func(*kgo.Record, error)) {
+}
+
+func (c *mockKafkaClient) ProduceSync(ctx context.Context, rs ...*kgo.Record) kgo.ProduceResults {
+	return nil
+}
+
+func (c *mockKafkaClient) PollRecords(ctx context.Context, num int) kgo.Fetches {
+	return <-c.fetches
+}
+
+func (c *mockKafkaClient) Flush(ctx context.Context) error {
+	return nil
+}
+
+func (c *mockKafkaClient) Close() {
+	close(c.fetches)
+}
+
+func TestSubscriber_Subscribe(t *testing.T) {
+	tests := []struct {
+		name           string
+		fetches        kgo.Fetches
+		wantErr        bool
+		msgHandler     func(record *kgo.Record)
+		wantTerminated bool
+		expectedErrCnt int
+	}{
+		{
+			name: "success",
+			fetches: kgo.Fetches{
+				{
+					Topics: []kgo.FetchTopic{
+						{
+							Topic: "test",
+							Partitions: []kgo.FetchPartition{
+								{
+									Records: []*kgo.Record{{Key: []byte("test")}},
+								},
+							},
+						},
+					},
+				},
+			},
+			msgHandler: func(record *kgo.Record) {
+				assert.Equal(t, []byte("test"), record.Key)
+			},
+		},
+		{
+			name: "non-fatal error",
+			fetches: kgo.Fetches{
+				{
+					Topics: []kgo.FetchTopic{
+						{
+							Topic: "test",
+							Partitions: []kgo.FetchPartition{
+								{
+									Err: errors.New("test"),
+								},
+							},
+						},
+					},
+				},
+			},
+			expectedErrCnt: 1,
+			wantErr:        true,
+		},
+		{
+			name: "non-fatal errors",
+			fetches: kgo.Fetches{
+				{
+					Topics: []kgo.FetchTopic{
+						{
+							Topic: "test",
+							Partitions: []kgo.FetchPartition{
+								{
+									Err: errors.New("test"),
+								},
+								{
+									Err: errors.New("test2"),
+								},
+							},
+						},
+					},
+				},
+			},
+			expectedErrCnt: 2,
+			wantErr:        true,
+		},
+		{
+			name: "fatal error (terminate subscriber)",
+			fetches: kgo.Fetches{
+				{
+					Topics: []kgo.FetchTopic{
+						{
+							Topic: "test",
+							Partitions: []kgo.FetchPartition{
+								{
+									Err: kerr.BrokerIDNotRegistered,
+								},
+							},
+						},
+					},
+				},
+			},
+			expectedErrCnt: 1,
+			wantErr:        true,
+			wantTerminated: true,
+		},
+		{
+			name: "fatal error (closed client)",
+			fetches: kgo.Fetches{
+				{
+					Topics: []kgo.FetchTopic{
+						{
+							Topic: "test",
+							Partitions: []kgo.FetchPartition{
+								{
+									Partition: 0,
+									Err:       kgo.ErrClientClosed,
+								},
+							},
+						},
+					},
+				},
+			},
+			wantTerminated: true,
+		},
+		{
+			name: "fatal and non-fatal errors",
+			fetches: kgo.Fetches{
+				{
+					Topics: []kgo.FetchTopic{
+						{
+							Topic: "test",
+							Partitions: []kgo.FetchPartition{
+								{
+									Err: errors.New("test"),
+								},
+								{
+									Err: kerr.BrokerIDNotRegistered,
+								},
+							},
+						},
+					},
+				},
+			},
+			wantErr:        true,
+			expectedErrCnt: 2,
+			wantTerminated: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			fetchesChan := make(chan kgo.Fetches, 1)
+			fetchesChan <- tt.fetches
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+
+			s := &Subscriber{
+				mu:         sync.Mutex{},
+				consumers:  make(map[string]map[int32]pconsumer),
+				numWorkers: 1,
+				handler: func(rec *kgo.Record) {
+					defer wg.Done()
+
+					tt.msgHandler(rec)
+				},
+				consumer: kafka.WrapClient(&kafka.Client{
+					KafkaClient: &mockKafkaClient{
+						fetches: fetchesChan,
+					},
+				}),
+			}
+
+			s.assigned(context.Background(), nil, getAssigns(1))
+
+			errCh := s.Subscribe(context.Background())
+			if tt.msgHandler != nil {
+				wg.Wait()
+			}
+
+			if tt.wantErr {
+				i := 0
+				for err := range errCh {
+					i++
+					assert.NotNil(t, err)
+					if i == tt.expectedErrCnt && !tt.wantTerminated {
+						break
+					}
+				}
+				assert.Equal(t, tt.expectedErrCnt, i)
+			}
+
+			if tt.wantTerminated {
+				_, ok := <-errCh
+				assert.False(t, ok)
+
+				err := s.Unsubscribe()
+				assert.Nil(t, err)
+
+				_, ok = <-fetchesChan
+				assert.False(t, ok)
+			}
+		})
+	}
+}
+
+func TestSubscriber_ConcurrentSubscribe(t *testing.T) {
+	tests := []struct {
+		name          string
+		numWorkers    int
+		numPartitions int
+	}{
+		{
+			name:          "1 worker, 1 partition",
+			numWorkers:    1,
+			numPartitions: 1,
+		}, {
+			name:          "1 worker, 10 partitions",
+			numWorkers:    1,
+			numPartitions: 10,
+		}, {
+			name:          "10 workers, 10 partitions",
+			numWorkers:    10,
+			numPartitions: 10,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			fetches := generateFetches(tt.numPartitions)
+			fetchesChan := make(chan kgo.Fetches, len(fetches))
+			fetchesChan <- fetches
+
+			var wg sync.WaitGroup
+			wg.Add(100)
+
+			var i uint64
+
+			s := &Subscriber{
+				mu:         sync.Mutex{},
+				consumers:  make(map[string]map[int32]pconsumer),
+				numWorkers: tt.numWorkers,
+				handler: func(rec *kgo.Record) {
+					atomic.AddUint64(&i, 1)
+					wg.Done()
+				},
+				consumer: kafka.WrapClient(&kafka.Client{
+					KafkaClient: &mockKafkaClient{
+						fetches: fetchesChan,
+					},
+				}),
+			}
+
+			s.assigned(context.Background(), nil, getAssigns(tt.numPartitions))
+
+			s.Subscribe(context.Background())
+
+			wg.Wait()
+
+			assert.Equal(t, 100, int(i))
+		})
+	}
+}
+
+func TestSubscriber_RevokePartition(t *testing.T) {
+	fetchesChan := make(chan kgo.Fetches, 1)
+
+	fetch := kgo.Fetches{
+		{
+			Topics: []kgo.FetchTopic{
+				{
+					Topic: "test",
+					Partitions: []kgo.FetchPartition{
+						{
+							Records: []*kgo.Record{{Key: []byte("test")}},
+						},
+					},
+				},
+			},
+		},
+	}
+
+	fetchesChan <- fetch
+
+	recv := make(chan struct{}, 1)
+
+	l, err := logger()
+
+	assert.Nil(t, err)
+
+	s := &Subscriber{
+		mu:         sync.Mutex{},
+		consumers:  make(map[string]map[int32]pconsumer),
+		logger:     l,
+		numWorkers: 1,
+		handler: func(rec *kgo.Record) {
+			recv <- struct{}{}
+		},
+		consumer: kafka.WrapClient(&kafka.Client{
+			KafkaClient: &mockKafkaClient{
+				fetches: fetchesChan,
+			},
+		}),
+	}
+
+	s.assigned(context.Background(), &kgo.Client{}, getAssigns(1))
+
+	s.Subscribe(context.Background())
+
+	select {
+	case <-recv:
+	case <-time.After(5 * time.Millisecond):
+		t.Fatal("expected to receive a message")
+	}
+
+	// make sure we stop the partition consumer
+	s.revoked(context.Background(), &kgo.Client{}, getAssigns(1))
+
+	fetchesChan <- fetch
+
+	select {
+	case <-recv:
+		t.Fatal("expected the consumer to be stopped")
+	case <-time.After(5 * time.Millisecond):
+	}
+
+	err = s.Unsubscribe()
+	assert.Nil(t, err)
+}
+
+func getAssigns(numPartitions int) map[string][]int32 {
+	assigns := make(map[string][]int32)
+	for i := 0; i < numPartitions; i++ {
+		assigns["test"] = append(assigns["test"], int32(i))
+	}
+	return assigns
+}
+
+func generateFetches(partitions int) kgo.Fetches {
+	fetches := kgo.Fetches{
+		{
+			Topics: []kgo.FetchTopic{
+				{
+					Topic: "test",
+				},
+			},
+		},
+	}
+
+	numRecords := 100 / partitions
+	for i := 0; i < partitions; i++ {
+		records := make([]*kgo.Record, numRecords)
+		for j := 0; j < numRecords; j++ {
+			records[j] = &kgo.Record{
+				Value: []byte(fmt.Sprintf("test %d", j)),
+			}
+		}
+
+		fetches[0].Topics[0].Partitions = append(fetches[0].Topics[0].Partitions, kgo.FetchPartition{
+			Partition: int32(i),
+			Records:   records,
+		})
+	}
+
+	return fetches
+}
+
+func logger() (sdklogger.Logger, error) {
+	var buffer bytes.Buffer
+	return sdklogger.NewBuilder(
+		&sdklogger.Config{
+			ConsoleEnabled:    true,
+			ConsoleJSONFormat: true,
+			ConsoleLevel:      "info",
+			FileEnabled:       false,
+		}).BuildTestLogger(&buffer)
+}
diff --git a/pkg/transport/kafka/encode_decode.go b/pkg/transport/kafka/encode_decode.go
new file mode 100644
index 0000000..86538e7
--- /dev/null
+++ b/pkg/transport/kafka/encode_decode.go
@@ -0,0 +1,25 @@
+package kafka
+
+import (
+	"context"
+
+	"github.com/twmb/franz-go/pkg/kgo"
+)
+
+// DecodeRequestFunc extracts a user-domain request object from
+// a Kafka message. It is designed to be used in Kafka Subscribers.
+type DecodeRequestFunc func(ctx context.Context, msg *kgo.Record) (request interface{}, err error)
+
+// EncodeRequestFunc encodes the passed request object into
+// a Kafka message object. It is designed to be used in Kafka Publishers.
+type EncodeRequestFunc func(context.Context, *kgo.Record, interface{}) error
+
+// EncodeResponseFunc encodes the passed response object into
+// a Kafka message object. It is designed to be used in Kafka Subscribers.
+type EncodeResponseFunc func(context.Context, *kgo.Record, interface{}) error
+
+// DecodeResponseFunc extracts a user-domain response object from a Kafka
+// response object. It is designed to be used in Kafka Publishers, for publisher-side
+// endpoints. One straightforward DecodeResponseFunc could be something that
+// JSON decodes from the response payload to the concrete response type.
+type DecodeResponseFunc func(context.Context, *kgo.Record) (response interface{}, err error)
diff --git a/pkg/transport/kafka/publisher.go b/pkg/transport/kafka/publisher.go
new file mode 100644
index 0000000..7269ce4
--- /dev/null
+++ b/pkg/transport/kafka/publisher.go
@@ -0,0 +1,205 @@
+package kafka
+
+import (
+	"context"
+	"encoding/json"
+	"time"
+
+	"github.com/go-kit/kit/endpoint"
+	"github.com/twmb/franz-go/pkg/kgo"
+)
+
+const (
+	defaultPublisherTimeout = 10 * time.Second
+)
+
+// Publisher wraps a single Kafka topic for message publishing
+// and implements endpoint.Endpoint.
+type Publisher struct {
+	handler   Handler
+	topic     string
+	enc       EncodeRequestFunc
+	dec       DecodeResponseFunc
+	before    []RequestFunc
+	after     []PublisherResponseFunc
+	deliverer Deliverer
+	timeout   time.Duration
+}
+
+// NewPublisher constructs a new publisher for a single Kafka topic,
+// which implements endpoint.Endpoint.
+func NewPublisher(
+	handler Handler,
+	topic string,
+	enc EncodeRequestFunc,
+	dec DecodeResponseFunc,
+	options ...PublisherOption,
+) *Publisher {
+	p := &Publisher{
+		handler:   handler,
+		topic:     topic,
+		deliverer: SyncDeliverer,
+		enc:       enc,
+		dec:       dec,
+		timeout:   defaultPublisherTimeout,
+	}
+	for _, opt := range options {
+		opt(p)
+	}
+
+	return p
+}
+
+// PublisherOption sets an optional parameter for publishers.
+type PublisherOption func(publisher *Publisher)
+
+// PublisherBefore sets the RequestFuncs that are applied to the outgoing publisher
+// request before it's invoked.
+func PublisherBefore(before ...RequestFunc) PublisherOption {
+	return func(p *Publisher) {
+		p.before = append(p.before, before...)
+	}
+}
+
+// PublisherAfter adds one or more PublisherResponseFuncs, which are applied to the
+// context after successful message publishing.
+// This is useful for context-manipulation operations.
+func PublisherAfter(after ...PublisherResponseFunc) PublisherOption {
+	return func(p *Publisher) {
+		p.after = append(p.after, after...)
+	}
+}
+
+// PublisherDeliverer sets the deliverer function that the Publisher invokes.
+func PublisherDeliverer(deliverer Deliverer) PublisherOption {
+	return func(p *Publisher) { p.deliverer = deliverer }
+}
+
+// PublisherTimeout sets the available timeout for a kafka request.
+func PublisherTimeout(timeout time.Duration) PublisherOption {
+	return func(p *Publisher) { p.timeout = timeout }
+}
+
+// Endpoint returns a usable endpoint that invokes message publishing.
+func (p Publisher) Endpoint() endpoint.Endpoint {
+	return func(ctx context.Context, request interface{}) (interface{}, error) {
+		ctx, cancel := context.WithTimeout(ctx, p.timeout)
+		defer cancel()
+
+		msg := &kgo.Record{
+			Topic: p.topic,
+		}
+
+		if err := p.enc(ctx, msg, request); err != nil {
+			return nil, err
+		}
+
+		for _, f := range p.before {
+			ctx = f(ctx, msg)
+		}
+
+		event, err := p.deliverer(ctx, p, msg)
+		if err != nil {
+			return nil, err
+		}
+
+		for _, f := range p.after {
+			ctx = f(ctx, event)
+		}
+
+		response, err := p.dec(ctx, event)
+		if err != nil {
+			return nil, err
+		}
+
+		return response, nil
+	}
+}
+
+// Deliverer is invoked by the Publisher to publish the specified Message, and to
+// retrieve the appropriate response Event object.
+type Deliverer func(
+	context.Context,
+	Publisher,
+	*kgo.Record,
+) (*kgo.Record, error)
+
+// SyncDeliverer is a deliverer that publishes the specified message
+// synchronously and returns the first produce result's record.
+// If the context times out while waiting for a reply, an error will be returned.
+func SyncDeliverer(ctx context.Context, pub Publisher, msg *kgo.Record) (*kgo.Record, error) {
+	results := pub.handler.ProduceSync(ctx, msg)
+
+	if len(results) > 0 && results[0].Err != nil {
+		return nil, results[0].Err
+	}
+
+	return results[0].Record, nil
+}
+
+// AsyncDeliverer delivers the supplied message and
+// returns a nil response.
+//
+// When using this deliverer please ensure that the supplied DecodeResponseFunc and
+// PublisherResponseFunc are able to handle nil-type responses.
+//
+// AsyncDeliverer produces the message with the context detached, because the actual
+// message production happens asynchronously (in another goroutine), by which time the
+// original context might already be canceled, causing the producer to fail. The detached
+// context will include values attached to the original context, but its deadline and
+// cancelation are reset. To provide a context for the asynchronous deliverer, please
+// use the AsyncDelivererCtx function instead.
+func AsyncDeliverer(ctx context.Context, pub Publisher, msg *kgo.Record) (*kgo.Record, error) {
+	pub.handler.Produce(detach{ctx: ctx}, msg, nil)
+
+	return nil, nil
+}
+
+// AsyncDelivererCtx delivers the supplied message and
+// returns a nil response.
+//
+// When using this deliverer please ensure that the supplied DecodeResponseFunc and
+// PublisherResponseFunc are able to handle nil-type responses.
+func AsyncDelivererCtx(ctx context.Context, pub Publisher, msg *kgo.Record) (*kgo.Record, error) {
+	pub.handler.Produce(ctx, msg, nil)
+
+	return nil, nil
+}
+
+// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a
+// JSON object to the Message value.
+// Many services can use it as a sensible default.
+func EncodeJSONRequest(_ context.Context, msg *kgo.Record, request interface{}) error {
+	rawJSON, err := json.Marshal(request)
+	if err != nil {
+		return err
+	}
+
+	msg.Value = rawJSON
+
+	return nil
+}
+
+// Handler is a handler interface to make testing possible.
+// It is highly recommended to use *kafka.Producer as the interface implementation.
+type Handler interface {
+	Produce(ctx context.Context, rec *kgo.Record, fn func(record *kgo.Record, err error))
+	ProduceSync(ctx context.Context, rs ...*kgo.Record) kgo.ProduceResults
+}
+
+type detach struct {
+	ctx context.Context
+}
+
+func (d detach) Deadline() (time.Time, bool) {
+	return time.Time{}, false
+}
+
+func (d detach) Done() <-chan struct{} {
+	return nil
+}
+
+func (d detach) Err() error {
+	return nil
+}
+
+func (d detach) Value(key interface{}) interface{} {
+	return d.ctx.Value(key)
+}
diff --git a/pkg/transport/kafka/publisher_test.go b/pkg/transport/kafka/publisher_test.go
new file mode 100644
index 0000000..9a79016
--- /dev/null
+++ b/pkg/transport/kafka/publisher_test.go
@@ -0,0 +1,390 @@
+package kafka
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/twmb/franz-go/pkg/kgo"
+
+	sdkloggercontext "github.com/scribd/go-sdk/pkg/context/logger"
+	sdkrequestidcontext "github.com/scribd/go-sdk/pkg/context/requestid"
+	sdklogger "github.com/scribd/go-sdk/pkg/logger"
+)
+
+const (
+	errString = "err"
+)
+
+type (
+	mockHandler struct {
+		producePromise func(ctx context.Context, record *kgo.Record, err error)
+		deliverAfter   time.Duration
+	}
+
+	testReq struct {
+		A int `json:"a"`
+	}
+	testRes struct {
+		A int `json:"a"`
+	}
+)
+
+func (h *mockHandler) ProduceSync(ctx context.Context, rs ...*kgo.Record) kgo.ProduceResults {
+	var results kgo.ProduceResults
+	done := make(chan struct{})
+
+	h.Produce(ctx, rs[0], func(record *kgo.Record, err error) {
+		results = append(results, kgo.ProduceResult{Record: record, Err: err})
+		close(done)
+	})
+
+	<-done
+	return results
+}
+
+func (h *mockHandler) Produce(ctx context.Context, rec *kgo.Record, promise func(record *kgo.Record, err error)) {
+	fn := func(rec *kgo.Record, err error) {
+		if promise != nil {
+			promise(rec, err)
+		}
+		if h.producePromise != nil {
+			h.producePromise(ctx, rec, err)
+		}
+	}
+
+	go func() {
+		select {
+		case <-ctx.Done():
+			fn(rec, ctx.Err())
+		case <-time.After(h.deliverAfter):
+			if fn != nil {
+				fn(rec, nil)
+			}
+		}
+	}()
+}
+
+// TestBadEncode tests if encode errors are handled properly.
+func TestBadEncode(t *testing.T) {
+	h := &mockHandler{}
+	pub := NewPublisher(
+		h,
+		"test",
+		func(context.Context, *kgo.Record, interface{}) error { return errors.New(errString) },
+		func(context.Context, *kgo.Record) (response interface{}, err error) { return struct{}{}, nil },
+	)
+	errChan := make(chan error, 1)
+	var err error
+	go func() {
+		_, pubErr := pub.Endpoint()(context.Background(), struct{}{})
+		errChan <- pubErr
+	}()
+	select {
+	case err = <-errChan:
+		break
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("Timed out waiting for result")
+	}
+	if err == nil {
+		t.Error("expected error")
+	}
+	if want, have := errString, err.Error(); want != have {
+		t.Errorf("want %s, have %s", want, have)
+	}
+}
+
+// TestBadDecode tests if decode errors are handled properly.
+func TestBadDecode(t *testing.T) {
+	h := &mockHandler{}
+
+	pub := NewPublisher(
+		h,
+		"test",
+		func(context.Context, *kgo.Record, interface{}) error { return nil },
+		func(context.Context, *kgo.Record) (response interface{}, err error) {
+			return struct{}{}, errors.New(errString)
+		},
+	)
+
+	var err error
+	errChan := make(chan error, 1)
+	go func() {
+		_, pubErr := pub.Endpoint()(context.Background(), struct{}{})
+		errChan <- pubErr
+	}()
+
+	select {
+	case err = <-errChan:
+		break
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("Timed out waiting for result")
+	}
+
+	if err == nil {
+		t.Error("expected error")
+	}
+	if want, have := errString, err.Error(); want != have {
+		t.Errorf("want %s, have %s", want, have)
+	}
+}
+
+// TestPublisherTimeout ensures that the publisher timeout mechanism works.
+func TestPublisherTimeout(t *testing.T) {
+	h := &mockHandler{
+		deliverAfter: time.Second,
+	}
+
+	pub := NewPublisher(
+		h,
+		"test",
+		func(context.Context, *kgo.Record, interface{}) error { return nil },
+		func(context.Context, *kgo.Record) (response interface{}, err error) {
+			return struct{}{}, nil
+		},
+		PublisherTimeout(50*time.Millisecond),
+	)
+
+	var err error
+	errChan := make(chan error, 1)
+	go func() {
+		_, pubErr := pub.Endpoint()(context.Background(), struct{}{})
+		errChan <- pubErr
+	}()
+
+	select {
+	case err = <-errChan:
+		break
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("timed out waiting for result")
+	}
+
+	if err == nil {
+		t.Error("expected error")
+	}
+	if want, have := context.DeadlineExceeded.Error(), err.Error(); want != have {
+		t.Errorf("want %s, have %s", want, have)
+	}
+}
+
+func TestSuccessfulPublisher(t *testing.T) {
+	mockReq := testReq{1}
+	mockRes := testRes{
+		A: 1,
+	}
+	_, err := json.Marshal(mockRes)
+	if err != nil {
+		t.Fatal(err)
+	}
+	h := &mockHandler{}
+
+	pub := NewPublisher(
+		h,
+		"test",
+		testReqEncoder,
+		testResMessageDecoder,
+	)
+	var res testRes
+	var ok bool
+	resChan := make(chan interface{}, 1)
+	errChan := make(chan error, 1)
+	go func() {
+		res, pubErr := pub.Endpoint()(context.Background(), mockReq)
+		if pubErr != nil {
+			errChan <- pubErr
+		} else {
+			resChan <- res
+		}
+	}()
+
+	select {
+	case response := <-resChan:
+		res, ok = response.(testRes)
+		if !ok {
+			t.Error("failed to assert endpoint response type")
+		}
+		break
+	case err = <-errChan:
+		break
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("timed out waiting for result")
+	}
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	if want, have := mockRes.A, res.A; want != have {
+		t.Errorf("want %d, have %d", want, have)
+	}
+}
+
+// TestAsyncPublisher tests that the AsyncDeliverer is working.
+func TestAsyncPublisher(t *testing.T) {
+	finishChan := make(chan struct{})
+	contextKey, contextValue := "test", "test"
+	h := &mockHandler{
+		producePromise: func(ctx context.Context, record *kgo.Record, err error) {
+			if err != nil {
+				t.Fatal(err)
+			}
+			if want, have := contextValue, ctx.Value(contextKey); want != have {
+				t.Errorf("want %s, have %s", want, have)
+			}
+			finishChan <- struct{}{}
+		},
+	}
+
+	pub := NewPublisher(
+		h,
+		"test",
+		func(context.Context, *kgo.Record, interface{}) error { return nil },
+		func(ctx context.Context, rec *kgo.Record) (response interface{}, err error) {
+			val := ctx.Value(contextKey).(string)
+			assert.Equal(t, contextValue, val)
+
+			return struct{}{}, nil
+		},
+		PublisherDeliverer(AsyncDeliverer),
+		PublisherTimeout(50*time.Millisecond),
+	)
+
+	var err error
+	errChan := make(chan error)
+	go func() {
+		ctx := context.WithValue(context.Background(), contextKey, contextValue)
+		ctx, cancel := context.WithCancel(ctx)
+		cancel()
+		_, pubErr := pub.Endpoint()(ctx, struct{}{})
+		if pubErr != nil {
+			errChan <- pubErr
+		}
+	}()
+
+	select {
+	case <-finishChan:
+		break
+	case err = <-errChan:
+		t.Errorf("unexpected error %s", err)
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("timed out waiting for result")
+	}
+}
+
+func TestSetRequestID(t *testing.T) {
+	h := &mockHandler{}
+
+	mockReq := testReq{1}
+
+	pub := NewPublisher(
+		h,
+		"test",
+		func(context.Context, *kgo.Record, interface{}) error { return nil },
+		func(context.Context, *kgo.Record) (response interface{}, err error) {
+			return struct{}{}, nil
+		},
+		PublisherBefore(SetRequestID()),
+		PublisherDeliverer(func(ctx context.Context, publisher Publisher, message *kgo.Record) (*kgo.Record, error) {
+			r, err := sdkrequestidcontext.Extract(ctx)
+			require.NotNil(t, r)
+			require.Nil(t, err)
+
+			return nil, nil
+		}),
+	)
+
+	_, err := pub.Endpoint()(context.Background(), mockReq)
+	require.Nil(t, err)
+}
+
+func TestSetLogger(t *testing.T) {
+	h := &mockHandler{}
+
+	mockReq := testReq{1}
+
+	// Inject this "owned" buffer as the output of the logger under test,
+	// so that the emitted fields can be asserted on below.
+	var buffer bytes.Buffer
+
+	config := &sdklogger.Config{
+		ConsoleEnabled:    true,
+		ConsoleJSONFormat: true,
+		ConsoleLevel:      "info",
+		FileEnabled:       false,
+	}
+	l, err := sdklogger.NewBuilder(config).BuildTestLogger(&buffer)
+	require.Nil(t, err)
+
+	pub := NewPublisher(
+		h,
+		"test",
+		func(context.Context, *kgo.Record, interface{}) error { return nil },
+		func(context.Context, *kgo.Record) (response interface{}, err error) {
+			return struct{}{}, nil
+		},
+		PublisherBefore(SetLogger(l)),
+		PublisherDeliverer(func(ctx context.Context, publisher Publisher, message *kgo.Record) (*kgo.Record, error) {
+			l, ctxErr := sdkloggercontext.Extract(ctx)
+			require.NotNil(t, l)
+			require.Nil(t, ctxErr)
+
+			l.Infof("test")
+
+			return nil, nil
+		}),
+		PublisherAfter(func(ctx context.Context, ev *kgo.Record) context.Context {
+			l, ctxErr := sdkloggercontext.Extract(ctx)
+			require.NotNil(t, l)
+			require.Nil(t, ctxErr)
+
+			return ctx
+		}),
+	)
+
+	_, err = pub.Endpoint()(context.Background(), mockReq)
+	require.Nil(t, err)
+
+	var fields map[string]interface{}
+	err = json.Unmarshal(buffer.Bytes(), &fields)
+	require.Nil(t, err)
+
+	assert.NotEmpty(t, fields["pubsub"])
+	assert.NotEmpty(t, fields["dd"])
+
+	pubsubFields := fields["pubsub"].(map[string]interface{})
+	assert.NotNil(t, pubsubFields["request_id"])
+}
+
+func testReqEncoder(_ context.Context, m *kgo.Record, request interface{}) error {
+	req, ok := request.(testReq)
+	if !ok {
+		return errors.New("type assertion failure")
+	}
+	b, err := json.Marshal(req)
+	if err != nil {
+		return err
+	}
+	m.Value = b
+	return nil
+}
+
+func testResMessageDecoder(_ context.Context, m *kgo.Record) (interface{}, error) {
+	return testResDecoder(m.Value)
+}
+
+func testResDecoder(b []byte) (interface{}, error) {
+	var obj testRes
+	err := json.Unmarshal(b, &obj)
+	return obj, err
+}
diff --git a/pkg/transport/kafka/request_response_funcs.go b/pkg/transport/kafka/request_response_funcs.go
new file mode 100644
index 0000000..9fdef25
--- /dev/null
+++ b/pkg/transport/kafka/request_response_funcs.go
@@ -0,0 +1,78 @@
+package kafka
+
+import (
+	"context"
+
+	"github.com/google/uuid"
+	"github.com/twmb/franz-go/pkg/kgo"
+
+	sdkloggercontext "github.com/scribd/go-sdk/pkg/context/logger"
+	sdkmetricscontext "github.com/scribd/go-sdk/pkg/context/metrics"
+	sdkrequestidcontext "github.com/scribd/go-sdk/pkg/context/requestid"
+	sdkinstrumentation "github.com/scribd/go-sdk/pkg/instrumentation"
+	sdklogger "github.com/scribd/go-sdk/pkg/logger"
+	sdkmetrics "github.com/scribd/go-sdk/pkg/metrics"
+)
+
+// RequestFunc may take information from a Kafka message and put it into a
+// request context. In Subscribers, RequestFuncs are executed prior to invoking the
+// endpoint.
+type RequestFunc func(ctx context.Context, msg *kgo.Record) context.Context
+
+// SubscriberResponseFunc may take information from a request context and use it to
+// manipulate a Publisher. SubscriberResponseFuncs are only executed in
+// consumers, after invoking the endpoint but prior to publishing a reply.
+type SubscriberResponseFunc func(ctx context.Context, response interface{}) context.Context
+
+// PublisherResponseFunc may take information from a request context.
+// PublisherResponseFuncs are only executed in producers, after a request has been produced.
+type PublisherResponseFunc func(ctx context.Context, msg *kgo.Record) context.Context
+
+// SetMetrics returns a RequestFunc that sets the Metrics client in the request context.
+func SetMetrics(m sdkmetrics.Metrics) RequestFunc {
+	return func(ctx context.Context, msg *kgo.Record) context.Context {
+		return sdkmetricscontext.ToContext(ctx, m)
+	}
+}
+
+// SetRequestID returns a RequestFunc that sets a RequestID in the request context if not previously set.
+func SetRequestID() RequestFunc {
+	return func(ctx context.Context, msg *kgo.Record) context.Context {
+		_, err := sdkrequestidcontext.Extract(ctx)
+		if err != nil {
+			if uuidObject, err := uuid.NewRandom(); err == nil {
+				requestID := uuidObject.String()
+				return sdkrequestidcontext.ToContext(ctx, requestID)
+			}
+		}
+
+		return ctx
+	}
+}
+
+// SetLogger returns a RequestFunc that sets the SDK Logger in the request context.
+// It will also try to set up context values in the logger fields.
+func SetLogger(l sdklogger.Logger) RequestFunc {
+	return func(ctx context.Context, msg *kgo.Record) context.Context {
+		logContext := sdkinstrumentation.TraceLogs(ctx)
+
+		requestID, err := sdkrequestidcontext.Extract(ctx)
+		if err != nil {
+			l.WithFields(sdklogger.Fields{
+				"error": err.Error(),
+			}).Tracef("Could not retrieve request id from the context")
+		}
+
+		logger := l.WithFields(sdklogger.Fields{
+			"pubsub": sdklogger.Fields{
+				"request_id": requestID,
+			},
+			"dd": sdklogger.Fields{
+				"trace_id": logContext.TraceID,
+				"span_id":  logContext.SpanID,
+			},
+		})
+
+		return sdkloggercontext.ToContext(ctx, logger)
+	}
+}
diff --git a/pkg/transport/kafka/subscriber.go b/pkg/transport/kafka/subscriber.go
new file mode 100644
index 0000000..17dac54
--- /dev/null
+++ b/pkg/transport/kafka/subscriber.go
@@ -0,0 +1,172 @@
+package kafka
+
+import (
+	"context"
+
+	"github.com/go-kit/kit/endpoint"
+	"github.com/go-kit/kit/log"
+	"github.com/go-kit/kit/transport"
+	"github.com/twmb/franz-go/pkg/kgo"
+
+	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
+
+	kafkasdk "github.com/scribd/go-sdk/pkg/instrumentation/kafka"
+	sdkkafka "github.com/scribd/go-sdk/pkg/pubsub/kafka"
+)
+
+// Subscriber wraps an endpoint and provides a handler for kafka messages.
+type Subscriber struct {
+	e            endpoint.Endpoint
+	dec          DecodeRequestFunc
+	before       []RequestFunc
+	after        []SubscriberResponseFunc
+	finalizer    []SubscriberFinalizerFunc
+	errorHandler transport.ErrorHandler
+	errorEncoder ErrorEncoder
+}
+
+// NewSubscriber constructs a new subscriber that provides a handler for kafka messages.
+func NewSubscriber(
+	e endpoint.Endpoint,
+	dec DecodeRequestFunc,
+	opts ...SubscriberOption,
+) *Subscriber {
+	c := &Subscriber{
+		e:            e,
+		dec:          dec,
+		errorEncoder: DefaultErrorEncoder,
+		errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()),
+	}
+	for _, opt := range opts {
+		opt(c)
+	}
+
+	return c
+}
+
+// SubscriberOption sets an optional parameter for subscribers.
+type SubscriberOption func(consumer *Subscriber)
+
+// SubscriberBefore functions are executed on the subscriber message object
+// before the request is decoded.
+func SubscriberBefore(before ...RequestFunc) SubscriberOption {
+	return func(c *Subscriber) {
+		c.before = append(c.before, before...)
+	}
+}
+
+// SubscriberAfter functions are executed on the subscriber reply after the
+// endpoint is invoked, but before anything is published to the reply.
+func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption {
+	return func(c *Subscriber) {
+		c.after = append(c.after, after...)
+	}
+}
+
+// SubscriberErrorEncoder is used to encode errors to the subscriber reply
+// whenever they're encountered in the processing of a request. Clients can
+// use this to provide custom error formatting. By default,
+// errors will be published with the DefaultErrorEncoder.
+func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption {
+	return func(s *Subscriber) { s.errorEncoder = ee }
+}
+
+// SubscriberErrorHandler is used to handle non-terminal errors. By default, non-terminal errors
+// are ignored. This is intended as a diagnostic measure.
+func SubscriberErrorHandler(errorHandler transport.ErrorHandler) SubscriberOption {
+	return func(c *Subscriber) {
+		c.errorHandler = errorHandler
+	}
+}
+
+// SubscriberFinalizer is executed at the end of every message processing.
+// By default, no finalizer is registered.
+func SubscriberFinalizer(f ...SubscriberFinalizerFunc) SubscriberOption {
+	return func(c *Subscriber) {
+		c.finalizer = append(c.finalizer, f...)
+	}
+}
+
+// ServeMsg provides a kafka.MsgHandler.
+func (s Subscriber) ServeMsg(h Handler) sdkkafka.MsgHandler {
+	return func(msg *kgo.Record) {
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+
+		if len(s.finalizer) > 0 {
+			defer func() {
+				for _, f := range s.finalizer {
+					f(ctx, msg)
+				}
+			}()
+		}
+
+		for _, f := range s.before {
+			ctx = f(ctx, msg)
+		}
+
+		request, err := s.dec(ctx, msg)
+		if err != nil {
+			s.errorEncoder(ctx, err, msg, h)
+			s.errorHandler.Handle(ctx, err)
+			return
+		}
+
+		response, err := s.e(ctx, request)
+		if err != nil {
+			s.errorEncoder(ctx, err, msg, h)
+			s.errorHandler.Handle(ctx, err)
+			return
+		}
+
+		for _, f := range s.after {
+			ctx = f(ctx, response)
+		}
+	}
+}
+
+// SubscriberFinalizerFunc can be used to perform work at the end of message processing,
+// after the response has been constructed. The principal
+// intended use is for request logging.
+type SubscriberFinalizerFunc func(ctx context.Context, msg *kgo.Record)
+
+// ErrorEncoder is responsible for encoding an error to the subscriber reply.
+// Users are encouraged to use custom ErrorEncoders to encode errors to
+// their replies, and will likely want to pass and check for their own error
+// types.
+type ErrorEncoder func(ctx context.Context,
+	err error, msg *kgo.Record, h Handler)
+
+// DefaultErrorEncoder simply ignores the message.
+func DefaultErrorEncoder(ctx context.Context,
+	err error, msg *kgo.Record, h Handler) {
+}
+
+// NewInstrumentedSubscriber constructs a new subscriber that provides a handler for kafka messages.
+// It also instruments the subscriber with Datadog tracing.
+func NewInstrumentedSubscriber(e endpoint.Endpoint, dec DecodeRequestFunc, opts ...SubscriberOption) *Subscriber {
+	options := []SubscriberOption{
+		SubscriberBefore(startMessageHandlerTrace),
+		SubscriberFinalizer(finishMessageHandlerTrace),
+	}
+
+	options = append(options, opts...)
+
+	return NewSubscriber(e, dec, options...)
+}
+
+func startMessageHandlerTrace(ctx context.Context, msg *kgo.Record) context.Context {
+	if spanctx, err := tracer.Extract(kafkasdk.NewMessageCarrier(msg)); err == nil {
+		span := tracer.StartSpan("kafka.msghandler", tracer.ChildOf(spanctx))
+
+		ctx = tracer.ContextWithSpan(ctx, span)
+	}
+
+	return ctx
+}
+
+func finishMessageHandlerTrace(ctx context.Context, msg *kgo.Record) {
+	if span, ok := tracer.SpanFromContext(ctx); ok {
+		span.Finish()
+	}
+}
diff --git a/pkg/transport/kafka/subscriber_test.go b/pkg/transport/kafka/subscriber_test.go
new file mode 100644
index 0000000..482f400
--- /dev/null
+++ b/pkg/transport/kafka/subscriber_test.go
@@ -0,0 +1,113 @@
+package kafka
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/twmb/franz-go/pkg/kgo"
+)
+
+// TestSubscriberBadDecode checks if decoder errors are handled properly.
+func TestSubscriberBadDecode(t *testing.T) {
+	errCh := make(chan error, 1)
+
+	sub := NewSubscriber(
+		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
+		func(context.Context, *kgo.Record) (interface{}, error) { return nil, errors.New("err!") },
+		SubscriberErrorEncoder(createTestErrorEncoder(errCh)),
+	)
+
+	sub.ServeMsg(nil)(&kgo.Record{})
+
+	var err error
+	select {
+	case err = <-errCh:
+		break
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("Timed out waiting for error")
+	}
+	if want, have := "err!", err.Error(); want != have {
+		t.Errorf("want %s, have %s", want, have)
+	}
+}
+
+// TestSubscriberBadEndpoint checks if endpoint errors are handled properly.
+func TestSubscriberBadEndpoint(t *testing.T) {
+	errCh := make(chan error, 1)
+
+	sub := NewSubscriber(
+		func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("err!") },
+		func(context.Context, *kgo.Record) (interface{}, error) { return struct{}{}, nil },
+		SubscriberErrorEncoder(createTestErrorEncoder(errCh)),
+	)
+
+	sub.ServeMsg(nil)(&kgo.Record{})
+
+	var err error
+	select {
+	case err = <-errCh:
+		break
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("Timed out waiting for error")
+	}
+	if want, have := "err!", err.Error(); want != have {
+		t.Errorf("want %s, have %s", want, have)
+	}
+}
+
+func TestSubscriberSuccess(t *testing.T) {
+	obj := testReq{A: 1}
+	b, err := json.Marshal(obj)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sub := NewSubscriber(
+		testEndpoint,
+		testReqDecoder,
+		SubscriberAfter(func(ctx context.Context, response interface{}) context.Context {
+			res := response.(testRes)
+			if res.A != 2 {
+				t.Errorf("got wrong result: %d", res.A)
+			}
+
+			return ctx
+		}),
+	)
+
+	// golangci-lint v1.44.2 on Linux reports that variable 'b' is declared but unused
+	// when the handler is invoked as a self-invoked function; this probably needs investigation.
+	handler := sub.ServeMsg(nil)
+	handler(&kgo.Record{
+		Value: b,
+	})
+}
+
+func createTestErrorEncoder(ch chan error) ErrorEncoder {
+	return func(ctx context.Context, err error, msg *kgo.Record, h Handler) {
+		ch <- err
+	}
+}
+
+func testReqDecoder(_ context.Context, m *kgo.Record) (interface{}, error) {
+	var obj testReq
+	err := json.Unmarshal(m.Value, &obj)
+	return obj, err
+}
+
+func testEndpoint(_ context.Context, request interface{}) (interface{}, error) {
+	req, ok := request.(testReq)
+	if !ok {
+		return nil, errors.New("type assertion error")
+	}
+
+	res := testRes{
+		A: req.A + 1,
+	}
+	return res, nil
+}
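
Usage sketch (not part of the patch): the snippet below shows how the pieces introduced here are meant to fit together. The transport-level Subscriber's ServeMsg handler becomes the pubsub-level Subscriber's MsgHandler, and the pubsub Publisher (which satisfies the transport Handler interface via Produce/ProduceSync) backs a go-kit publishing endpoint. Only APIs added in this diff are used; the application name, topic, the helper names wireSubscriber/wirePublisher, and the origin of the pubsub.Kafka config values and logger are illustrative assumptions.

package kafkaexample

import (
	"context"

	"github.com/twmb/franz-go/pkg/kgo"

	sdklogger "github.com/scribd/go-sdk/pkg/logger"
	"github.com/scribd/go-sdk/pkg/pubsub"
	pubsubkafka "github.com/scribd/go-sdk/pkg/pubsub/kafka"
	transportkafka "github.com/scribd/go-sdk/pkg/transport/kafka"
)

// wireSubscriber consumes a topic through a go-kit endpoint (decode -> endpoint),
// with Datadog tracing added by NewInstrumentedSubscriber.
func wireSubscriber(ctx context.Context, kafkaCfg pubsub.Kafka, l sdklogger.Logger) error {
	ep := func(_ context.Context, request interface{}) (interface{}, error) {
		// business logic goes here; request is whatever the decoder returned
		return nil, nil
	}
	decode := func(_ context.Context, msg *kgo.Record) (interface{}, error) {
		return string(msg.Value), nil
	}

	// ServeMsg(nil) is acceptable here because DefaultErrorEncoder ignores the handler.
	t := transportkafka.NewInstrumentedSubscriber(ep, decode)

	sub, err := pubsubkafka.NewSubscriber(pubsubkafka.Config{
		ApplicationName: "example-app", // assumption
		KafkaConfig:     kafkaCfg,      // GroupId, Topic, Workers etc. come from app config
		Logger:          l,
		MsgHandler:      t.ServeMsg(nil),
	})
	if err != nil {
		return err
	}
	defer sub.Unsubscribe()

	// Subscribe returns an unbuffered error channel, so it must be drained.
	for err := range sub.Subscribe(ctx) {
		l.WithError(err).Errorf("fetch error")
	}
	return nil
}

// wirePublisher publishes through a go-kit endpoint; AsyncDeliverer returns a
// nil record, so the decode function must tolerate nil.
func wirePublisher(ctx context.Context, kafkaCfg pubsub.Kafka) error {
	pub, err := pubsubkafka.NewPublisher(pubsubkafka.Config{
		ApplicationName: "example-app", // assumption
		KafkaConfig:     kafkaCfg,
	})
	if err != nil {
		return err
	}
	defer func() { _ = pub.Stop(ctx) }()

	ep := transportkafka.NewPublisher(
		pub,             // the pubsub Publisher satisfies the transport Handler interface
		"example-topic", // assumption
		transportkafka.EncodeJSONRequest,
		func(_ context.Context, _ *kgo.Record) (interface{}, error) { return nil, nil },
		transportkafka.PublisherDeliverer(transportkafka.AsyncDeliverer),
	).Endpoint()

	_, err = ep(ctx, map[string]string{"hello": "world"})
	return err
}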