Skip to content

Commit

Permalink
Add unit tests for the GetRecords function (#64)
Browse files Browse the repository at this point in the history
A closed shard returns the last sequence number and no error. The current implementation leads to an infinite loop if the shard is closed. NextShardIterator checking is enough. That's why I remove the getShardIterator call
  • Loading branch information
Vincent authored and harlow committed Jul 28, 2018
1 parent a123922 commit 049445e
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 9 deletions.
12 changes: 3 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ func (c *KinesisClient) GetRecords(ctx context.Context, streamName, shardID, las
}
continue
}

for _, r := range resp.Records {
select {
case <-ctx.Done():
Expand All @@ -117,16 +116,11 @@ func (c *KinesisClient) GetRecords(ctx context.Context, streamName, shardID, las
lastSeqNum = *r.SequenceNumber
}
}

if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
if err != nil {
errc <- fmt.Errorf("get shard iterator error: %v", err)
return
}
} else {
shardIterator = resp.NextShardIterator
errc <- fmt.Errorf("get shard iterator error: %v", err)
return
}
shardIterator = resp.NextShardIterator
}
}
}()
Expand Down
150 changes: 150 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package consumer_test

import (
"testing"

"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
"github.com/harlow/kinesis-consumer"
)

func TestKinesisClient_GetRecords_SuccessfullyRun(t *testing.T) {
kinesisClient := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: make([]*kinesis.Record, 0),
}, nil
},
}
kinesisClientOpt := consumer.WithKinesis(kinesisClient)
c, err := consumer.NewKinesisClient(kinesisClientOpt)
if err != nil {
t.Fatalf("New kinesis client error: %v", err)
}
ctx, cancelFunc := context.WithCancel(context.Background())
recordsChan, errorsChan, err := c.GetRecords(ctx, "myStream", "shardId-000000000000", "")

if recordsChan == nil {
t.Errorf("records channel expected not nil, got %v", recordsChan)
}
if errorsChan == nil {
t.Errorf("errors channel expected not nil, got %v", recordsChan)
}
if err != nil {
t.Errorf("error expected nil, got %v", err)
}

cancelFunc()
}

func TestKinesisClient_GetRecords_SuccessfullyRetrievesThreeRecordsAtOnce(t *testing.T) {
expectedResults := []*kinesis.Record{
{
SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576195"),
},
{
SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576196"),
},
{
SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576197"),
}}
kinesisClient := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: expectedResults,
}, nil
},
}
kinesisClientOpt := consumer.WithKinesis(kinesisClient)
c, err := consumer.NewKinesisClient(kinesisClientOpt)
if err != nil {
t.Fatalf("new kinesis client error: %v", err)
}
ctx, cancelFunc := context.WithCancel(context.Background())
recordsChan, _, err := c.GetRecords(ctx, "TestStream", "shardId-000000000000", "")

if recordsChan == nil {
t.Fatalf("records channel expected not nil, got %v", recordsChan)
}
if err != nil {
t.Fatalf("error expected nil, got %v", err)
}
var results []*consumer.Record
results = append(results, <-recordsChan, <-recordsChan, <-recordsChan)
if len(results) != 3 {
t.Errorf("number of records expected 3, got %v", len(results))
}
for i, r := range results {
if r != expectedResults[i] {
t.Errorf("record expected %v, got %v", expectedResults[i], r)
}
}

cancelFunc()
}

func TestKinesisClient_GetRecords_ShardIsClosed(t *testing.T) {
kinesisClient := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: make([]*consumer.Record, 0),
}, nil
},
}
kinesisClientOpt := consumer.WithKinesis(kinesisClient)
c, err := consumer.NewKinesisClient(kinesisClientOpt)
if err != nil {
t.Fatalf("new kinesis client error: %v", err)
}
ctx, cancelFunc := context.WithCancel(context.Background())
_, errorsChan, err := c.GetRecords(ctx, "TestStream", "shardId-000000000000", "")

if errorsChan == nil {
t.Fatalf("errors channel expected equals not nil, got %v", errorsChan)
}
if err != nil {
t.Fatalf("error expected, got %v", err)
}

err = <-errorsChan
if err == nil {
t.Errorf("error expected, got %v", err)
}

cancelFunc()
}

type kinesisClientMock struct {
kinesisiface.KinesisAPI
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
}

func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return c.getRecordsMock(in)
}

func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return c.getShardIteratorMock(in)
}

0 comments on commit 049445e

Please sign in to comment.