diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/Consumer.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/Consumer.java index 83dcd8d6c1616..4cd54420200be 100644 --- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/Consumer.java +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/Consumer.java @@ -324,7 +324,7 @@ public Future sendMessages(final List entries, EntryBatch if (pendingAcks != null) { int batchSize = batchSizes.getBatchSize(i); int stickyKeyHash = getStickyKeyHash(entry); - long[] ackSet = getCursorAckSet(PositionImpl.get(entry.getLedgerId(), entry.getEntryId())); + long[] ackSet = batchIndexesAcks == null ? null : batchIndexesAcks.getAckSet(i); if (ackSet != null) { unackedMessages -= (batchSize - BitSet.valueOf(ackSet).cardinality()); } diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/BatchMessageWithBatchIndexLevelTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/BatchMessageWithBatchIndexLevelTest.java index b2fbe824b3305..3a4cee7f2be83 100644 --- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/BatchMessageWithBatchIndexLevelTest.java +++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/BatchMessageWithBatchIndexLevelTest.java @@ -18,8 +18,17 @@ */ package org.apache.pulsar.broker.service; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; +import com.carrotsearch.hppc.ObjectSet; +import java.lang.reflect.Field; import java.util.ArrayList; import java.util.List; import java.util.UUID; @@ -28,10 +37,14 @@ import lombok.Cleanup; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; +import org.apache.bookkeeper.mledger.Entry; +import org.apache.bookkeeper.mledger.impl.ManagedCursorImpl; +import org.apache.bookkeeper.mledger.impl.PositionImpl; import org.apache.pulsar.broker.BrokerTestUtil; import org.apache.pulsar.broker.service.persistent.PersistentDispatcherMultipleConsumers; import org.apache.pulsar.broker.service.persistent.PersistentTopic; import org.apache.pulsar.client.api.Consumer; +import org.apache.pulsar.client.api.ConsumerBuilder; import org.apache.pulsar.client.api.Message; import org.apache.pulsar.client.api.MessageId; import org.apache.pulsar.client.api.MessageIdAdv; @@ -39,8 +52,10 @@ import org.apache.pulsar.client.api.Schema; import org.apache.pulsar.client.api.SubscriptionInitialPosition; import org.apache.pulsar.client.api.SubscriptionType; +import org.apache.pulsar.client.impl.BatchMessageIdImpl; import org.apache.pulsar.client.impl.ConsumerImpl; import org.apache.pulsar.common.util.FutureUtil; +import org.apache.pulsar.common.util.collections.BitSetRecyclable; import org.awaitility.Awaitility; import org.testng.Assert; import org.testng.annotations.BeforeClass; @@ -401,4 +416,171 @@ public void testMixIndexAndNonIndexUnAckMessageCount() throws Exception { assertEquals(admin.topics().getStats(topicName).getSubscriptions() .get("sub").getUnackedMessages(), 0); } + + @Test + public void testUnAckMessagesWhenConcurrentDeliveryAndAck() throws Exception { + final String topicName = BrokerTestUtil.newUniqueName("persistent://prop/ns-abc/tp"); + final String subName = "s1"; + final int receiverQueueSize = 500; + admin.topics().createNonPartitionedTopic(topicName); + admin.topics().createSubscription(topicName, subName, MessageId.earliest); + ConsumerBuilder consumerBuilder = pulsarClient.newConsumer(Schema.STRING) + .topic(topicName) + .receiverQueueSize(receiverQueueSize) + .subscriptionName(subName) + .enableBatchIndexAcknowledgment(true) + .subscriptionType(SubscriptionType.Shared) + .isAckReceiptEnabled(true); + + // Send 100 messages. + Producer producer = pulsarClient.newProducer(Schema.STRING) + .topic(topicName) + .enableBatching(true) + .batchingMaxPublishDelay(1, TimeUnit.HOURS) + .create(); + CompletableFuture lastSent = null; + for (int i = 0; i < 100; i++) { + lastSent = producer.sendAsync(i + ""); + } + producer.flush(); + lastSent.join(); + + // When consumer1 is closed, may some messages are in the client memory(it they are being acked now). + Consumer consumer1 = consumerBuilder.consumerName("c1").subscribe(); + Message[] messagesInClientMemory = new Message[2]; + for (int i = 0; i < 2; i++) { + Message msg = consumer1.receive(2, TimeUnit.SECONDS); + assertNotNull(msg); + messagesInClientMemory[i] = msg; + } + ConsumerImpl consumer2 = (ConsumerImpl) consumerBuilder.consumerName("c2").subscribe(); + Awaitility.await().until(() -> consumer2.isConnected()); + + // The consumer2 will receive messages after consumer1 closed. + // Insert a delay mechanism to make the flow like below: + // 1. Close consumer1, then the 100 messages will be redelivered. + // 2. Read redeliver messages. No messages were acked at this time. + // 3. The in-flight ack of two messages is finished. + // 4. Send the messages to consumer2, consumer2 will get all the 100 messages. + CompletableFuture receiveMessageSignal2 = new CompletableFuture<>(); + org.apache.pulsar.broker.service.Consumer serviceConsumer2 = + makeConsumerReceiveMessagesDelay(topicName, subName, "c2", receiveMessageSignal2); + // step 1: close consumer. + consumer1.close(); + // step 2: wait for read messages from replay queue. + Thread.sleep(2 * 1000); + // step 3: wait for the in-flight ack. + BitSetRecyclable bitSetRecyclable = createBitSetRecyclable(100); + long ledgerId = 0, entryId = 0; + for (Message message : messagesInClientMemory) { + BatchMessageIdImpl msgId = (BatchMessageIdImpl) message.getMessageId(); + bitSetRecyclable.clear(msgId.getBatchIndex()); + ledgerId = msgId.getLedgerId(); + entryId = msgId.getEntryId(); + } + getCursor(topicName, subName).delete(PositionImpl.get(ledgerId, entryId, bitSetRecyclable.toLongArray())); + // step 4: send messages to consumer2. + receiveMessageSignal2.complete(null); + // Verify: Consumer2 will get all the 100 messages, and "unAckMessages" is 100. + List messages2 = new ArrayList<>(); + while (true) { + Message msg = consumer2.receive(2, TimeUnit.SECONDS); + if (msg == null) { + break; + } + messages2.add(msg); + } + assertEquals(messages2.size(), 100); + assertEquals(serviceConsumer2.getUnackedMessages(), 100); + // After the messages were pop out, the permits in the client memory went to 100. + Awaitility.await().untilAsserted(() -> { + assertEquals(serviceConsumer2.getAvailablePermits() + consumer2.getAvailablePermits(), + receiverQueueSize); + }); + + // cleanup. + producer.close(); + consumer2.close(); + admin.topics().delete(topicName, false); + } + + private BitSetRecyclable createBitSetRecyclable(int batchSize) { + BitSetRecyclable bitSetRecyclable = new BitSetRecyclable(batchSize); + for (int i = 0; i < batchSize; i++) { + bitSetRecyclable.set(i); + } + return bitSetRecyclable; + } + + private ManagedCursorImpl getCursor(String topic, String sub) { + PersistentTopic persistentTopic = + (PersistentTopic) pulsar.getBrokerService().getTopic(topic, false).join().get(); + PersistentDispatcherMultipleConsumers dispatcher = + (PersistentDispatcherMultipleConsumers) persistentTopic.getSubscription(sub).getDispatcher(); + return (ManagedCursorImpl) dispatcher.getCursor(); + } + + /*** + * After {@param signal} complete, the consumer({@param consumerName}) start to receive messages. + */ + private org.apache.pulsar.broker.service.Consumer makeConsumerReceiveMessagesDelay(String topic, String sub, + String consumerName, + CompletableFuture signal) throws Exception { + PersistentTopic persistentTopic = + (PersistentTopic) pulsar.getBrokerService().getTopic(topic, false).join().get(); + PersistentDispatcherMultipleConsumers dispatcher = + (PersistentDispatcherMultipleConsumers) persistentTopic.getSubscription(sub).getDispatcher(); + org.apache.pulsar.broker.service.Consumer serviceConsumer = null; + for (org.apache.pulsar.broker.service.Consumer c : dispatcher.getConsumers()){ + if (c.consumerName().equals(consumerName)) { + serviceConsumer = c; + break; + } + } + final org.apache.pulsar.broker.service.Consumer originalConsumer = serviceConsumer; + + // Insert a delay signal. + org.apache.pulsar.broker.service.Consumer spyServiceConsumer = spy(originalConsumer); + doAnswer(invocation -> { + List entries = (List) invocation.getArguments()[0]; + EntryBatchSizes batchSizes = (EntryBatchSizes) invocation.getArguments()[1]; + EntryBatchIndexesAcks batchIndexesAcks = (EntryBatchIndexesAcks) invocation.getArguments()[2]; + int totalMessages = (int) invocation.getArguments()[3]; + long totalBytes = (long) invocation.getArguments()[4]; + long totalChunkedMessages = (long) invocation.getArguments()[5]; + RedeliveryTracker redeliveryTracker = (RedeliveryTracker) invocation.getArguments()[6]; + return signal.thenApply(__ -> originalConsumer.sendMessages(entries, batchSizes, batchIndexesAcks, totalMessages, totalBytes, + totalChunkedMessages, redeliveryTracker)).join(); + }).when(spyServiceConsumer) + .sendMessages(anyList(), any(), any(), anyInt(), anyLong(), anyLong(), any()); + doAnswer(invocation -> { + List entries = (List) invocation.getArguments()[0]; + EntryBatchSizes batchSizes = (EntryBatchSizes) invocation.getArguments()[1]; + EntryBatchIndexesAcks batchIndexesAcks = (EntryBatchIndexesAcks) invocation.getArguments()[2]; + int totalMessages = (int) invocation.getArguments()[3]; + long totalBytes = (long) invocation.getArguments()[4]; + long totalChunkedMessages = (long) invocation.getArguments()[5]; + RedeliveryTracker redeliveryTracker = (RedeliveryTracker) invocation.getArguments()[6]; + long epoch = (long) invocation.getArguments()[7]; + return signal.thenApply(__ -> originalConsumer.sendMessages(entries, batchSizes, batchIndexesAcks, totalMessages, totalBytes, + totalChunkedMessages, redeliveryTracker, epoch)).join(); + }).when(spyServiceConsumer) + .sendMessages(anyList(), any(), any(), anyInt(), anyLong(), anyLong(), any(), anyLong()); + + // Replace the consumer. + Field fConsumerList = AbstractDispatcherMultipleConsumers.class.getDeclaredField("consumerList"); + Field fConsumerSet = AbstractDispatcherMultipleConsumers.class.getDeclaredField("consumerSet"); + fConsumerList.setAccessible(true); + fConsumerSet.setAccessible(true); + List consumerList = + (List) fConsumerList.get(dispatcher); + ObjectSet consumerSet = + (ObjectSet) fConsumerSet.get(dispatcher); + + consumerList.remove(originalConsumer); + consumerSet.removeAll(originalConsumer); + consumerList.add(spyServiceConsumer); + consumerSet.add(spyServiceConsumer); + return originalConsumer; + } }