From eba9a3f53457ae54f6594ae90921a279505dbfd8 Mon Sep 17 00:00:00 2001 From: Maksim Terekhin Date: Tue, 16 Apr 2024 16:09:55 +0200 Subject: [PATCH] feat: Add PubSub Kafka client --- pkg/pubsub/kafka/config.go | 154 ++++++++++ pkg/pubsub/kafka/partitionconsumer.go | 42 +++ pkg/pubsub/kafka/pool.go | 33 +++ pkg/pubsub/kafka/publisher.go | 85 ++++++ pkg/pubsub/kafka/subscriber.go | 214 ++++++++++++++ pkg/pubsub/kafka/subscriber_test.go | 404 ++++++++++++++++++++++++++ 6 files changed, 932 insertions(+) create mode 100644 pkg/pubsub/kafka/config.go create mode 100644 pkg/pubsub/kafka/partitionconsumer.go create mode 100644 pkg/pubsub/kafka/pool.go create mode 100644 pkg/pubsub/kafka/publisher.go create mode 100644 pkg/pubsub/kafka/subscriber.go create mode 100644 pkg/pubsub/kafka/subscriber_test.go diff --git a/pkg/pubsub/kafka/config.go b/pkg/pubsub/kafka/config.go new file mode 100644 index 0000000..d043d45 --- /dev/null +++ b/pkg/pubsub/kafka/config.go @@ -0,0 +1,154 @@ +package kafka + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + "time" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/twmb/franz-go/pkg/kgo" + awssasl "github.com/twmb/franz-go/pkg/sasl/aws" + "github.com/twmb/franz-go/pkg/sasl/plain" + + sdklogger "github.com/scribd/go-sdk/pkg/logger" + "github.com/scribd/go-sdk/pkg/pubsub" +) + +// Config provides a common configuration for Kafka PubSub clients. +type Config struct { + // Application name that will be used in a serviceName provided to tracer spans + ApplicationName string + // Kafka configuration provided by go-sdk + KafkaConfig pubsub.Kafka + // AWS session reference, it will be used in case AWS MSK IAM authentication mechanism is used + AwsSession *session.Session + // MsgHandler is a function that will be called when a message is received + MsgHandler MsgHandler + Logger sdklogger.Logger +} + +const tlsConnectionTimeout = 10 * time.Second + +func newConfig(c Config, opts ...kgo.Opt) ([]kgo.Opt, error) { + options := []kgo.Opt{ + kgo.SeedBrokers(c.KafkaConfig.BrokerUrls...), + kgo.ClientID(c.KafkaConfig.ClientId), + } + + if c.KafkaConfig.SASL.Enabled { + switch c.KafkaConfig.SASLMechanism() { + case pubsub.Plain: + options = append(options, getPlainSaslOption(c.KafkaConfig.SASL)) + case pubsub.AWSMskIam: + options = append(options, getAwsMskIamSaslOption(c.KafkaConfig.SASL.AWSMskIam, c.AwsSession)) + } + } + + if c.KafkaConfig.TLS.Enabled || c.KafkaConfig.SecurityProtocol == "ssl" { + var caCertPool *x509.CertPool + + if c.KafkaConfig.TLS.Ca != "" { + caCertPool = x509.NewCertPool() + caCertPool.AppendCertsFromPEM([]byte(c.KafkaConfig.TLS.Ca)) + } + + var certificates []tls.Certificate + if c.KafkaConfig.TLS.Cert != "" && c.KafkaConfig.TLS.CertKey != "" { + cert, err := tls.X509KeyPair([]byte(c.KafkaConfig.TLS.Cert), []byte(c.KafkaConfig.TLS.CertKey)) + if err != nil { + return nil, err + } + certificates = []tls.Certificate{cert} + } + + if c.KafkaConfig.Cert != "" && c.KafkaConfig.CertKey != "" { + cert, err := tls.X509KeyPair([]byte(c.KafkaConfig.Cert), []byte(c.KafkaConfig.CertKey)) + if err != nil { + return nil, err + } + certificates = []tls.Certificate{cert} + } + + var skipTLSVerify bool + if c.KafkaConfig.TLS.InsecureSkipVerify || !c.KafkaConfig.SSLVerificationEnabled { + skipTLSVerify = true + } + + tlsDialer := &tls.Dialer{ + NetDialer: &net.Dialer{Timeout: tlsConnectionTimeout}, + Config: &tls.Config{ + InsecureSkipVerify: skipTLSVerify, + Certificates: certificates, + RootCAs: caCertPool, 
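+				// NOTE: with TLS.InsecureSkipVerify set, or SSLVerificationEnabled
+				// turned off, both certificate-chain and hostname verification are
+				// skipped for the broker connection.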
+ }, + } + + options = append(options, kgo.Dialer(tlsDialer.DialContext)) + } + + options = append(options, opts...) + + return options, nil +} + +func getPlainSaslOption(saslConf pubsub.SASL) kgo.Opt { + return kgo.SASL(plain.Auth{ + User: saslConf.Username, + Pass: saslConf.Password, + }.AsMechanism()) +} + +func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session) kgo.Opt { + var opt kgo.Opt + + // no AWS session provided + if s == nil { + opt = kgo.SASL(awssasl.Auth{ + AccessKey: iamConf.AccessKey, + SecretKey: iamConf.SecretKey, + SessionToken: iamConf.SessionToken, + UserAgent: iamConf.UserAgent, + }.AsManagedStreamingIAMMechanism()) + } else { + opt = kgo.SASL( + awssasl.ManagedStreamingIAM(func(ctx context.Context) (awssasl.Auth, error) { + // If assumable role is not provided, we try to get credentials from the provided AWS session + if iamConf.AssumableRole == "" { + val, err := s.Config.Credentials.Get() + if err != nil { + return awssasl.Auth{}, err + } + + return awssasl.Auth{ + AccessKey: val.AccessKeyID, + SecretKey: val.SecretAccessKey, + SessionToken: val.SessionToken, + UserAgent: iamConf.UserAgent, + }, nil + } + + svc := sts.New(s) + + res, stsErr := svc.AssumeRole(&sts.AssumeRoleInput{ + RoleArn: &iamConf.AssumableRole, + RoleSessionName: &iamConf.SessionName, + }) + if stsErr != nil { + return awssasl.Auth{}, stsErr + } + + return awssasl.Auth{ + AccessKey: *res.Credentials.AccessKeyId, + SecretKey: *res.Credentials.SecretAccessKey, + SessionToken: *res.Credentials.SessionToken, + UserAgent: iamConf.UserAgent, + }, nil + }), + ) + } + + return opt +} diff --git a/pkg/pubsub/kafka/partitionconsumer.go b/pkg/pubsub/kafka/partitionconsumer.go new file mode 100644 index 0000000..824f66c --- /dev/null +++ b/pkg/pubsub/kafka/partitionconsumer.go @@ -0,0 +1,42 @@ +package kafka + +import ( + "context" + + "github.com/twmb/franz-go/pkg/kgo" + + sdkkafka "github.com/scribd/go-sdk/pkg/instrumentation/kafka" + sdklogger "github.com/scribd/go-sdk/pkg/logger" +) + +type pconsumer struct { + pool *pool + + quit chan struct{} + done chan struct{} + recs chan *sdkkafka.FetchPartition +} + +func (pc *pconsumer) consume(cl *kgo.Client, logger sdklogger.Logger, shouldCommit bool, handler func(*kgo.Record)) { + defer close(pc.done) + + for { + select { + case <-pc.quit: + return + case p := <-pc.recs: + p.EachRecord(func(rec *kgo.Record) { + pc.pool.Schedule(func() { + defer p.ConsumeRecord(rec) + + handler(rec) + }) + }) + if shouldCommit { + if err := cl.CommitRecords(context.Background(), p.Records...); err != nil { + logger.WithError(err).Errorf("Partition consumer failed to commit records") + } + } + } + } +} diff --git a/pkg/pubsub/kafka/pool.go b/pkg/pubsub/kafka/pool.go new file mode 100644 index 0000000..e80ce1b --- /dev/null +++ b/pkg/pubsub/kafka/pool.go @@ -0,0 +1,33 @@ +package kafka + +type pool struct { + sem chan struct{} + work chan func() +} + +func newPool(size int) *pool { + p := &pool{ + sem: make(chan struct{}, size), + work: make(chan func()), + } + + return p +} + +func (p *pool) Schedule(task func()) { + select { + case p.work <- task: + return + case p.sem <- struct{}{}: + go p.worker(task) + } +} + +func (p *pool) worker(task func()) { + defer func() { <-p.sem }() + + for { + task() + task = <-p.work + } +} diff --git a/pkg/pubsub/kafka/publisher.go b/pkg/pubsub/kafka/publisher.go new file mode 100644 index 0000000..69377c1 --- /dev/null +++ b/pkg/pubsub/kafka/publisher.go @@ -0,0 +1,85 @@ +package kafka + +import ( + "context" + "fmt" + 
"time" + + "github.com/twmb/franz-go/pkg/kgo" + + sdkkafka "github.com/scribd/go-sdk/pkg/instrumentation/kafka" +) + +type ( + Publisher struct { + producer *sdkkafka.Client + } +) + +const ( + defaultFlushTimeout = time.Second * 10 + + publisherServiceNameSuffix = "pubsub-publisher" +) + +// NewPublisher is a tiny wrapper around the go-sdk kafka.Client and provides API to Publish kafka messages. +func NewPublisher(c Config, opts ...kgo.Opt) (*Publisher, error) { + serviceName := fmt.Sprintf("%s-%s", c.ApplicationName, publisherServiceNameSuffix) + + cfg, err := newConfig(c, opts...) + if err != nil { + return nil, err + } + + cfg = append(cfg, []kgo.Opt{ + kgo.ProduceRequestTimeout(c.KafkaConfig.Publisher.WriteTimeout), + kgo.RecordRetries(c.KafkaConfig.Publisher.MaxAttempts), + }...) + + producer, err := sdkkafka.NewClient(cfg, sdkkafka.WithServiceName(serviceName)) + if err != nil { + return nil, err + } + + return &Publisher{producer: producer}, nil +} + +// Publish publishes kgo.Record message. +func (p *Publisher) Publish(ctx context.Context, rec *kgo.Record, fn func(record *kgo.Record, err error)) { + p.producer.Produce(ctx, rec, fn) +} + +// Produce is an alias to Publish to satisfy kafka go-kit transport. +func (p *Publisher) Produce(ctx context.Context, rec *kgo.Record, fn func(record *kgo.Record, err error)) { + p.Publish(ctx, rec, fn) +} + +// ProduceSync publishes kgo.Record messages synchronously. +func (p *Publisher) ProduceSync(ctx context.Context, rs ...*kgo.Record) kgo.ProduceResults { + return p.producer.ProduceSync(ctx, rs...) +} + +// GetKafkaProducer returns underlying kafka.Producer for fine-grained tuning purposes. +func (p *Publisher) GetKafkaProducer() *sdkkafka.Client { + return p.producer +} + +// Stop flushes and waits for outstanding messages and requests to complete delivery. +// It also closes a Producer instance. +func (p *Publisher) Stop(ctx context.Context) error { + if _, deadlineSet := ctx.Deadline(); !deadlineSet { + timeoutCtx, cancel := context.WithTimeout(ctx, defaultFlushTimeout) + defer cancel() + + ctx = timeoutCtx + } + + err := p.producer.Flush(ctx) + if err != nil { + return err + } + + p.producer.Close() + + return nil +} diff --git a/pkg/pubsub/kafka/subscriber.go b/pkg/pubsub/kafka/subscriber.go new file mode 100644 index 0000000..d256f8e --- /dev/null +++ b/pkg/pubsub/kafka/subscriber.go @@ -0,0 +1,214 @@ +package kafka + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/twmb/franz-go/pkg/kerr" + "github.com/twmb/franz-go/pkg/kgo" + "github.com/twmb/franz-go/pkg/kmsg" + + sdkkafka "github.com/scribd/go-sdk/pkg/instrumentation/kafka" + sdklogger "github.com/scribd/go-sdk/pkg/logger" +) + +type ( + Subscriber struct { + logger sdklogger.Logger + consumer *sdkkafka.Client + autoCommitDisabled bool + mu sync.Mutex + consumers map[string]map[int32]pconsumer + handler func(rec *kgo.Record) + numWorkers int + maxRecords int + } + MsgHandler func(msg *kgo.Record) +) + +const ( + subscriberServiceNameSuffix = "pubsub-subscriber" + + defaultMaxRecords = 10000 +) + +// NewSubscriber is a tiny wrapper around the sdk kafka.Client and provides API to subscribe to a kafka topic. +func NewSubscriber(c Config, opts ...kgo.Opt) (*Subscriber, error) { + serviceName := fmt.Sprintf("%s-%s", c.ApplicationName, subscriberServiceNameSuffix) + + cfg, err := newConfig(c, opts...) 
+ if err != nil { + return nil, err + } + + autoCommitDisabled := !c.KafkaConfig.Subscriber.AutoCommit.Enabled + if autoCommitDisabled { + cfg = append(cfg, kgo.DisableAutoCommit()) + } + if c.KafkaConfig.Subscriber.BlockRebalance { + cfg = append(cfg, kgo.BlockRebalanceOnPoll()) + } + + s := &Subscriber{ + logger: c.Logger, + mu: sync.Mutex{}, + consumers: make(map[string]map[int32]pconsumer), + numWorkers: c.KafkaConfig.Subscriber.Workers, + handler: c.MsgHandler, + autoCommitDisabled: autoCommitDisabled, + maxRecords: c.KafkaConfig.Subscriber.MaxRecords, + } + + if s.maxRecords == 0 { + s.maxRecords = defaultMaxRecords + } + + cfg = append(cfg, []kgo.Opt{ + kgo.ConsumerGroup(c.KafkaConfig.Subscriber.GroupId), + kgo.ConsumeTopics(c.KafkaConfig.Subscriber.Topic), + kgo.OnPartitionsLost(s.lost), + kgo.OnPartitionsRevoked(s.revoked), + kgo.OnPartitionsAssigned(s.assigned), + }...) + + client, err := sdkkafka.NewClient(cfg, sdkkafka.WithServiceName(serviceName)) + if err != nil { + return nil, err + } + s.consumer = client + + return s, nil +} + +// Subscribe subscribes to a configured topic and reads messages. +// Returns unbuffered channel to inspect possible errors. +func (s *Subscriber) Subscribe(ctx context.Context) chan error { + ch := make(chan error) + + go func() { + defer close(ch) + + for { + fetches := s.consumer.PollRecords(ctx, s.maxRecords) + if fetches.IsClientClosed() { + return + } + + if fetches.Err() != nil { + var containsFatalErr bool + + fetches.EachError(func(_ string, _ int32, err error) { + if !containsFatalErr { + containsFatalErr = isFatalFetchError(err) + } + ch <- err + }) + + if containsFatalErr { + return + } + } + + fetches.EachTopic(func(t kgo.FetchTopic) { + s.mu.Lock() + tconsumers := s.consumers[t.Topic] + s.mu.Unlock() + + if tconsumers == nil { + return + } + t.EachPartition(func(p kgo.FetchPartition) { + pc, ok := tconsumers[p.Partition] + if !ok { + return + } + select { + case pc.recs <- s.consumer.WrapFetchPartition(ctx, p): + case <-pc.quit: + } + }) + }) + if kgoClient, ok := s.consumer.KafkaClient.(*kgo.Client); ok { + // this call does nothing in case rebalance is not blocked + kgoClient.AllowRebalance() + } + } + }() + + return ch +} + +func (s *Subscriber) assigned(_ context.Context, cl *kgo.Client, assigned map[string][]int32) { + s.mu.Lock() + defer s.mu.Unlock() + + for topic, partitions := range assigned { + if s.consumers[topic] == nil { + s.consumers[topic] = make(map[int32]pconsumer) + } + for _, partition := range partitions { + pc := pconsumer{ + quit: make(chan struct{}), + recs: make(chan *sdkkafka.FetchPartition), + pool: newPool(s.numWorkers), + done: make(chan struct{}), + } + s.consumers[topic][partition] = pc + go pc.consume(cl, s.logger, s.autoCommitDisabled, s.handler) + } + } +} + +func (s *Subscriber) stopConsumers(lost map[string][]int32) { + s.mu.Lock() + defer s.mu.Unlock() + + var wg sync.WaitGroup + defer wg.Wait() + + for topic, partitions := range lost { + ptopics := s.consumers[topic] + for _, partition := range partitions { + pc := ptopics[partition] + delete(ptopics, partition) + + if len(ptopics) == 0 { + delete(s.consumers, topic) + } + close(pc.quit) + wg.Add(1) + go func() { <-pc.done; wg.Done() }() + } + } +} + +func (s *Subscriber) lost(_ context.Context, _ *kgo.Client, lost map[string][]int32) { + s.stopConsumers(lost) +} + +func (s *Subscriber) revoked(ctx context.Context, cl *kgo.Client, lost map[string][]int32) { + s.stopConsumers(lost) + if !s.autoCommitDisabled { + cl.CommitOffsetsSync(ctx, 
cl.MarkedOffsets(), + func(cl *kgo.Client, _ *kmsg.OffsetCommitRequest, _ *kmsg.OffsetCommitResponse, err error) { + if err != nil { + s.logger.WithError(err).Errorf("Revoke commit failed") + } + }, + ) + } +} + +func (s *Subscriber) Unsubscribe() error { + s.consumer.Close() + + return nil +} + +func isFatalFetchError(err error) bool { + var kafkaErr *kerr.Error + + return errors.As(err, &kafkaErr) +} diff --git a/pkg/pubsub/kafka/subscriber_test.go b/pkg/pubsub/kafka/subscriber_test.go new file mode 100644 index 0000000..5e72c6b --- /dev/null +++ b/pkg/pubsub/kafka/subscriber_test.go @@ -0,0 +1,404 @@ +package kafka + +import ( + "bytes" + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/twmb/franz-go/pkg/kerr" + "github.com/twmb/franz-go/pkg/kgo" + + "github.com/scribd/go-sdk/pkg/instrumentation/kafka" + sdklogger "github.com/scribd/go-sdk/pkg/logger" +) + +type mockKafkaClient struct { + fetches chan kgo.Fetches +} + +func (c *mockKafkaClient) Produce(ctx context.Context, r *kgo.Record, promise func(*kgo.Record, error)) { +} + +func (c *mockKafkaClient) ProduceSync(ctx context.Context, rs ...*kgo.Record) kgo.ProduceResults { + return nil +} + +func (c *mockKafkaClient) PollRecords(ctx context.Context, num int) kgo.Fetches { + return <-c.fetches +} + +func (c *mockKafkaClient) Flush(ctx context.Context) error { + return nil +} + +func (c *mockKafkaClient) Close() { + close(c.fetches) +} + +func TestSubscriber_Subscribe(t *testing.T) { + tests := []struct { + name string + fetches kgo.Fetches + wantErr bool + msgHandler func(record *kgo.Record) + wantTerminated bool + expectedErrCnt int + }{ + { + name: "success", + fetches: kgo.Fetches{ + { + Topics: []kgo.FetchTopic{ + { + Topic: "test", + Partitions: []kgo.FetchPartition{ + { + Records: []*kgo.Record{{Key: []byte("test")}}, + }, + }, + }, + }, + }, + }, + msgHandler: func(record *kgo.Record) { + assert.Equal(t, []byte("test"), record.Key) + }, + }, + { + name: "non-fatal error", + fetches: kgo.Fetches{ + { + Topics: []kgo.FetchTopic{ + { + Topic: "test", + Partitions: []kgo.FetchPartition{ + { + Err: errors.New("test"), + }, + }, + }, + }, + }, + }, + expectedErrCnt: 1, + wantErr: true, + }, + { + name: "non-fatal errors", + fetches: kgo.Fetches{ + { + Topics: []kgo.FetchTopic{ + { + Topic: "test", + Partitions: []kgo.FetchPartition{ + { + Err: errors.New("test"), + }, + { + Err: errors.New("test2"), + }, + }, + }, + }, + }, + }, + expectedErrCnt: 2, + }, + { + name: "fatal error (terminate subscriber)", + fetches: kgo.Fetches{ + { + Topics: []kgo.FetchTopic{ + { + Topic: "test", + Partitions: []kgo.FetchPartition{ + { + Err: kerr.BrokerIDNotRegistered, + }, + }, + }, + }, + }, + }, + expectedErrCnt: 1, + wantErr: true, + wantTerminated: true, + }, + { + name: "fatal error (closed client)", + fetches: kgo.Fetches{ + { + Topics: []kgo.FetchTopic{ + { + Topic: "test", + Partitions: []kgo.FetchPartition{ + { + Partition: 0, + Err: kgo.ErrClientClosed, + }, + }, + }, + }, + }, + }, + wantTerminated: true, + }, + { + name: "fatal and non-fatal errors", + fetches: kgo.Fetches{ + { + Topics: []kgo.FetchTopic{ + { + Topic: "test", + Partitions: []kgo.FetchPartition{ + { + Err: errors.New("test"), + }, + { + Err: kerr.BrokerIDNotRegistered, + }, + }, + }, + }, + }, + }, + wantErr: true, + expectedErrCnt: 2, + wantTerminated: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fetchesChan := make(chan kgo.Fetches, 1) + fetchesChan 
<- tt.fetches + + var wg sync.WaitGroup + wg.Add(1) + + s := &Subscriber{ + mu: sync.Mutex{}, + consumers: make(map[string]map[int32]pconsumer), + numWorkers: 1, + handler: func(rec *kgo.Record) { + defer wg.Done() + + tt.msgHandler(rec) + }, + consumer: kafka.WrapClient(&kafka.Client{ + KafkaClient: &mockKafkaClient{ + fetches: fetchesChan, + }, + }), + } + + s.assigned(context.Background(), nil, getAssigns(1)) + + errCh := s.Subscribe(context.Background()) + if tt.msgHandler != nil { + wg.Wait() + } + + if tt.wantErr { + i := 0 + for err := range errCh { + i++ + assert.NotNil(t, err) + if i == tt.expectedErrCnt && !tt.wantTerminated { + break + } + } + assert.Equal(t, tt.expectedErrCnt, i) + } + + if tt.wantTerminated { + _, ok := <-errCh + assert.False(t, ok) + + err := s.Unsubscribe() + assert.Nil(t, err) + + _, ok = <-fetchesChan + assert.False(t, ok) + } + }) + } +} + +func TestSubscriber_ConcurrentSubscribe(t *testing.T) { + tests := []struct { + name string + numWorkers int + numPartitions int + }{ + { + name: "1 worker, 1 partition", + numWorkers: 1, + numPartitions: 1, + }, { + name: "1 worker, 10 partitions", + numWorkers: 1, + numPartitions: 10, + }, { + name: "10 workers, 10 partitions", + numWorkers: 10, + numPartitions: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fetches := generateFetches(tt.numPartitions) + fetchesChan := make(chan kgo.Fetches, len(fetches)) + fetchesChan <- fetches + + var wg sync.WaitGroup + wg.Add(100) + + var i uint64 + + s := &Subscriber{ + mu: sync.Mutex{}, + consumers: make(map[string]map[int32]pconsumer), + numWorkers: tt.numWorkers, + handler: func(rec *kgo.Record) { + atomic.AddUint64(&i, 1) + wg.Done() + }, + consumer: kafka.WrapClient(&kafka.Client{ + KafkaClient: &mockKafkaClient{ + fetches: fetchesChan, + }, + }), + } + + s.assigned(context.Background(), nil, getAssigns(tt.numPartitions)) + + s.Subscribe(context.Background()) + + wg.Wait() + + assert.Equal(t, 100, int(i)) + }) + } +} + +func TestSubscriber_RevokePartition(t *testing.T) { + fetchesChan := make(chan kgo.Fetches, 1) + + fetch := kgo.Fetches{ + { + Topics: []kgo.FetchTopic{ + { + Topic: "test", + Partitions: []kgo.FetchPartition{ + { + Records: []*kgo.Record{{Key: []byte("test")}}, + }, + }, + }, + }, + }, + } + + fetchesChan <- fetch + + recv := make(chan struct{}, 1) + + l, err := logger() + + assert.Nil(t, err) + + s := &Subscriber{ + mu: sync.Mutex{}, + consumers: make(map[string]map[int32]pconsumer), + logger: l, + numWorkers: 1, + handler: func(rec *kgo.Record) { + recv <- struct{}{} + }, + consumer: kafka.WrapClient(&kafka.Client{ + KafkaClient: &mockKafkaClient{ + fetches: fetchesChan, + }, + }), + } + + s.assigned(context.Background(), &kgo.Client{}, getAssigns(1)) + + s.Subscribe(context.Background()) + + select { + case <-recv: + case <-time.After(5 * time.Millisecond): + t.Fatal("expected to receive a message") + } + + // make sure we stop the partition consumer + s.revoked(context.Background(), &kgo.Client{}, getAssigns(1)) + + fetchesChan <- fetch + + select { + case <-recv: + t.Fatal("expected to consumer to be stopped") + case <-time.After(5 * time.Millisecond): + } + + err = s.Unsubscribe() + assert.Nil(t, err) +} + +func getAssigns(numPartitions int) map[string][]int32 { + assigns := make(map[string][]int32) + for i := 0; i < numPartitions; i++ { + assigns["test"] = append(assigns["test"], int32(i)) + } + return assigns +} + +func generateFetches(partitions int) kgo.Fetches { + fetches := kgo.Fetches{ + { + Topics: 
[]kgo.FetchTopic{ + { + Topic: "test", + }, + }, + }, + } + + numRecords := 100 / partitions + for i := 0; i < partitions; i++ { + records := make([]*kgo.Record, numRecords) + for j := 0; j < numRecords; j++ { + records[j] = &kgo.Record{ + Value: []byte(fmt.Sprintf("test %d", j)), + } + } + + fetches[0].Topics[0].Partitions = append(fetches[0].Topics[0].Partitions, kgo.FetchPartition{ + Partition: int32(i), + Records: records, + }) + } + + return fetches +} + +func logger() (sdklogger.Logger, error) { + var buffer bytes.Buffer + return sdklogger.NewBuilder( + &sdklogger.Config{ + ConsoleEnabled: true, + ConsoleJSONFormat: true, + ConsoleLevel: "info", + FileEnabled: false, + }).BuildTestLogger(&buffer) +}
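
A minimal wiring sketch of the new API for reviewers, assuming a kafka.Config that is already populated from the service's go-sdk configuration (broker URLs, subscriber group/topic, MsgHandler, Logger); runSubscriber, publishOne, the "log"-based error reporting and the topic/value arguments are illustrative placeholders, not part of the package.

package main

import (
	"context"
	"log"

	"github.com/twmb/franz-go/pkg/kgo"

	pubsubkafka "github.com/scribd/go-sdk/pkg/pubsub/kafka"
)

// runSubscriber starts the consumer and blocks until ctx is cancelled,
// then closes the client so the Subscribe loop terminates.
func runSubscriber(ctx context.Context, cfg pubsubkafka.Config) error {
	sub, err := pubsubkafka.NewSubscriber(cfg)
	if err != nil {
		return err
	}

	errCh := sub.Subscribe(ctx)
	go func() {
		// Surface fetch errors; the channel is closed once the consumer
		// terminates (client closed or fatal Kafka error).
		for err := range errCh {
			log.Printf("subscriber error: %v", err)
		}
	}()

	<-ctx.Done()             // e.g. signal handling in the caller
	return sub.Unsubscribe() // closes the client and ends the Subscribe loop
}

// publishOne produces a single record synchronously and waits for delivery.
func publishOne(ctx context.Context, cfg pubsubkafka.Config, topic string, value []byte) error {
	pub, err := pubsubkafka.NewPublisher(cfg)
	if err != nil {
		return err
	}
	defer func() {
		if stopErr := pub.Stop(ctx); stopErr != nil {
			log.Printf("publisher stop: %v", stopErr)
		}
	}()

	res := pub.ProduceSync(ctx, &kgo.Record{Topic: topic, Value: value})
	return res.FirstErr()
}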