diff --git a/pubsub/gochannel/pubsub.go b/pubsub/gochannel/pubsub.go index f47d0ca2b..965be4498 100644 --- a/pubsub/gochannel/pubsub.go +++ b/pubsub/gochannel/pubsub.go @@ -154,7 +154,7 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) (<-chan wg.Add(1) go func() { - subscriber.sendMessageToSubscriber(message, logFields) + subscriber.sendMessageToSubscriber(message, logFields, g.config.BlockPublishUntilSubscriberAck) wg.Done() }() } @@ -236,7 +236,7 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag msg := g.persistedMessages[topic][i] logFields := watermill.LogFields{"message_uuid": msg.UUID, "topic": topic} - go s.sendMessageToSubscriber(msg, logFields) + go s.sendMessageToSubscriber(msg, logFields, g.config.BlockPublishUntilSubscriberAck) } } @@ -341,7 +341,7 @@ func (s *subscriber) Close() { close(s.outputChannel) } -func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields) { +func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields, blockPublishUntilSubscriberAck bool) { s.sending.Lock() defer s.sending.Unlock() @@ -370,6 +370,11 @@ SendToSubscriber: return } + if !blockPublishUntilSubscriberAck { + s.logger.Trace("Sent message to subscriber without ack", logFields) + return + } + select { case <-msgToSend.Acked(): s.logger.Trace("Message acked", logFields) diff --git a/pubsub/gochannel/pubsub_test.go b/pubsub/gochannel/pubsub_test.go index de2dfb5b6..1b5322e0f 100644 --- a/pubsub/gochannel/pubsub_test.go +++ b/pubsub/gochannel/pubsub_test.go @@ -108,6 +108,72 @@ func TestPublishSubscribe_block_until_ack(t *testing.T) { } } +func TestPublishSubscribe_do_not_block_without_ack_required(t *testing.T) { + t.Helper() + + messagesCount := 10 + subscribersCount := 10 + + pubSub := gochannel.NewGoChannel( + gochannel.Config{ + OutputChannelBuffer: int64(messagesCount), + Persistent: true, + }, + watermill.NewStdLogger(true, false), + ) + + allSent := sync.WaitGroup{} + allSent.Add(messagesCount) + allReceived := sync.WaitGroup{} + + sentMessages := message.Messages{} + subscriberReceivedCh := make(chan message.Messages, subscribersCount) + for i := 0; i < subscribersCount; i++ { + allReceived.Add(1) + + go func(i int) { + allMsgReceived := make(message.Messages, 0) + msgs, err := pubSub.Subscribe(context.Background(), "topic") + require.NoError(t, err) + + for received := range msgs { + allMsgReceived = append(allMsgReceived, received) + if len(allMsgReceived) >= len(sentMessages) { + break + } + } + subscriberReceivedCh <- allMsgReceived + allReceived.Done() + }(i) + } + + go func() { + for i := 0; i < messagesCount; i++ { + msg := message.NewMessage(watermill.NewUUID(), nil) + sentMessages = append(sentMessages, msg) + + go func() { + require.NoError(t, pubSub.Publish("topic", msg)) + allSent.Done() + }() + } + }() + + log.Println("waiting for all sent") + allSent.Wait() + + log.Println("waiting for all received") + allReceived.Wait() + + close(subscriberReceivedCh) + + log.Println("asserting") + + for subMsgs := range subscriberReceivedCh { + tests.AssertAllMessagesReceived(t, sentMessages, subMsgs) + } +} + func TestPublishSubscribe_race_condition_on_subscribe(t *testing.T) { testsCount := 15 if testing.Short() { @@ -228,19 +294,6 @@ func testPublishSubscribeSubRace(t *testing.T) { allSent.Add(messagesCount) allReceived := sync.WaitGroup{} - sentMessages := message.Messages{} - go func() { - for i := 0; i < messagesCount; i++ { - msg := message.NewMessage(watermill.NewUUID(), nil) - sentMessages = append(sentMessages, msg) - - go func() { - require.NoError(t, pubSub.Publish("topic", msg)) - allSent.Done() - }() - } - }() - subscriberReceivedCh := make(chan message.Messages, subscribersCount) for i := 0; i < subscribersCount; i++ { allReceived.Add(1) @@ -256,6 +309,19 @@ func testPublishSubscribeSubRace(t *testing.T) { }() } + sentMessages := message.Messages{} + go func() { + for i := 0; i < messagesCount; i++ { + msg := message.NewMessage(watermill.NewUUID(), nil) + sentMessages = append(sentMessages, msg) + + go func() { + require.NoError(t, pubSub.Publish("topic", msg)) + allSent.Done() + }() + } + }() + log.Println("waiting for all sent") allSent.Wait() diff --git a/pubsub/tests/test_pubsub.go b/pubsub/tests/test_pubsub.go index 18de3acc8..8e489efcb 100644 --- a/pubsub/tests/test_pubsub.go +++ b/pubsub/tests/test_pubsub.go @@ -477,12 +477,12 @@ func TestResendOnError( var publishedMessages message.Messages allMessagesSent := make(chan struct{}) - publishedMessages = PublishSimpleMessages(t, messagesToSend, pub, topicName) - close(allMessagesSent) - messages, err := sub.Subscribe(context.Background(), topicName) require.NoError(t, err) + publishedMessages = PublishSimpleMessages(t, messagesToSend, pub, topicName) + close(allMessagesSent) + NackLoop: for i := 0; i < nacksCount; i++ { select { @@ -524,6 +524,9 @@ func TestNoAck( require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) } + messages, err := sub.Subscribe(context.Background(), topicName) + require.NoError(t, err) + for i := 0; i < 2; i++ { id := watermill.NewUUID() log.Printf("sending %s", id) @@ -534,9 +537,6 @@ func TestNoAck( require.NoError(t, err) } - messages, err := sub.Subscribe(context.Background(), topicName) - require.NoError(t, err) - receivedMessage := make(chan struct{}) unlockAck := make(chan struct{}, 1) go func() {