Skip to content

Commit

Permalink
Fix race condition of hasMessageAvailable by reading latest startMess…
Browse files Browse the repository at this point in the history
…ageId each time

### Motivation

See apache/pulsar-client-python#199

There is a race condition when `hasMessageAvailable` is called after
`seek` if the start message ID of `Reader` is earliest.

In `ConsumerImpl::hasMessageAvailableAsync`, if the connection is not
established at the moment, `lastDequedMessageId_` will be `earliest`
because no message is received. Since `lastMessageIdInBroker_` is also
`earliest`, `getLastMessageIdAsync` will be called and then it comes at

https://github.com/apache/pulsar-client-cpp/blob/e2cacb7dfb57b6d059b49fead2e1611548ff89b0/lib/ConsumerImpl.cc#L1554

However, before `getLastMessageIdAsync` is called, `messageId` was
`earliest` because `lastDequedMessageId_` and `startMessageId_` were
both `earliest`. However, when the callback is called, the
`startMessageId_` has already been updated to `latest` in
`connectionOpened`, so we should compare to `latest`.

### Modifications

In the callback of `getLastMessageIdAsync`, retrieve the latest value of
`startMessageId_` to compare rather then reusing the old value.

Refactor the seek flow to reset the seek states and trigger the callback
after updating the `startMessageId_`.

`ReaderTest.testHasMessageAvailableAfterSeekToEnd` is added to cover the
changes.
  • Loading branch information
BewareMyPower committed Mar 5, 2024
1 parent 747c186 commit e6e0adc
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 59 deletions.
158 changes: 105 additions & 53 deletions lib/ConsumerImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <pulsar/MessageIdBuilder.h>

#include <algorithm>
#include <sstream>
#include <utility>

#include "AckGroupingTracker.h"
#include "AckGroupingTrackerDisabled.h"
Expand Down Expand Up @@ -236,18 +238,34 @@ Future<Result, bool> ConsumerImpl::connectionOpened(const ClientConnectionPtr& c
// sending the subscribe request.
cnx->registerConsumer(consumerId_, get_shared_this_ptr());

if (duringSeek_) {
boost::optional<MessageId> startMessageId = boost::none;
Lock seekLock{mutexForSeek_};
if (seekStatus_ != SeekStatus::NOT_STARTED) {
ackGroupingTrackerPtr_->flushAndClean();
if (seekStatus_ == SeekStatus::COMPLETED) {
auto seekCallback = std::move(seekCallback_);
auto seekMessageId = std::move(seekMessageId_);
seekStatus_ = SeekStatus::NOT_STARTED;
LOG_INFO(getName() << "Seek successfully");
seekLock.unlock();

{
std::lock_guard<std::mutex> lockForMessageId{mutexForMessageId_};
startMessageId = startMessageId_ = seekMessageId;
}

if (seekCallback) {
seekCallback(ResultOk);
}
}
} else {
seekLock.unlock();
std::lock_guard<std::mutex> lockForMessageId(mutexForMessageId_);
startMessageId = startMessageId_ = clearReceiveQueue();
}

Lock lockForMessageId(mutexForMessageId_);
// Update startMessageId so that we can discard messages after delivery restarts
const auto startMessageId = clearReceiveQueue();
const auto subscribeMessageId =
(subscriptionMode_ == Commands::SubscriptionModeNonDurable) ? startMessageId : boost::none;
startMessageId_ = startMessageId;
lockForMessageId.unlock();

unAckedMessageTrackerPtr_->clear();

ClientImplPtr client = client_.lock();
Expand Down Expand Up @@ -586,11 +604,13 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
// try convert key value data.
m.impl_->convertPayloadToKeyValue(config_.getSchema());

const auto startMessageId = startMessageId_.get();
Lock messageIdLock{mutexForMessageId_};
const auto startMessageId = startMessageId_;
messageIdLock.unlock();
if (isPersistent_ && startMessageId &&
m.getMessageId().ledgerId() == startMessageId.value().ledgerId() &&
m.getMessageId().entryId() == startMessageId.value().entryId() &&
isPriorEntryIndex(m.getMessageId().entryId())) {
isPriorEntryIndex(m.getMessageId().entryId(), startMessageId.value())) {
LOG_DEBUG(getName() << " Ignoring message from before the startMessageId: "
<< startMessageId.value());
return;
Expand Down Expand Up @@ -712,7 +732,9 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection
auto batchSize = batchedMessage.impl_->metadata.num_messages_in_batch();
LOG_DEBUG("Received Batch messages of size - " << batchSize
<< " -- msgId: " << batchedMessage.getMessageId());
const auto startMessageId = startMessageId_.get();
Lock messageIdLock{mutexForMessageId_};
const auto startMessageId = startMessageId_;
messageIdLock.unlock();

int skippedMessages = 0;

Expand Down Expand Up @@ -744,7 +766,7 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection
// to the startMessageId
if (isPersistent_ && msgId.ledgerId() == startMessageId.value().ledgerId() &&
msgId.entryId() == startMessageId.value().entryId() &&
isPriorBatchIndex(msgId.batchIndex())) {
isPriorBatchIndex(msgId.batchIndex(), startMessageId.value())) {
LOG_DEBUG(getName() << "Ignoring message from before the startMessageId"
<< msg.getMessageId());
++skippedMessages;
Expand Down Expand Up @@ -1050,11 +1072,8 @@ void ConsumerImpl::messageProcessed(Message& msg, bool track) {
* not seen by the application
*/
boost::optional<MessageId> ConsumerImpl::clearReceiveQueue() {
bool expectedDuringSeek = true;
if (duringSeek_.compare_exchange_strong(expectedDuringSeek, false)) {
return seekMessageId_.get();
} else if (subscriptionMode_ == Commands::SubscriptionModeDurable) {
return startMessageId_.get();
if (subscriptionMode_ == Commands::SubscriptionModeDurable) {
return startMessageId_;
}
Message nextMessageInQueue;
if (incomingMessages_.peekAndClear(nextMessageInQueue)) {
Expand All @@ -1080,7 +1099,7 @@ boost::optional<MessageId> ConsumerImpl::clearReceiveQueue() {
} else {
// No message was received or dequeued by this consumer. Next message would still be the
// startMessageId
return startMessageId_.get();
return startMessageId_;
}
}

Expand Down Expand Up @@ -1500,17 +1519,10 @@ void ConsumerImpl::seekAsync(uint64_t timestamp, ResultCallback callback) {

bool ConsumerImpl::isReadCompacted() { return readCompacted_; }

inline bool hasMoreMessages(const MessageId& lastMessageIdInBroker, const MessageId& messageId) {
return lastMessageIdInBroker > messageId && lastMessageIdInBroker.entryId() != -1;
}

void ConsumerImpl::hasMessageAvailableAsync(HasMessageAvailableCallback callback) {
const auto startMessageId = startMessageId_.get();
Lock lock(mutexForMessageId_);
const auto messageId =
(lastDequedMessageId_ == MessageId::earliest()) ? startMessageId.value() : lastDequedMessageId_;

if (messageId == MessageId::latest()) {
if (lastDequedMessageId_ == MessageId::earliest() &&
startMessageId_.value_or(MessageId::earliest()) == MessageId::latest()) {
lock.unlock();
auto self = get_shared_this_ptr();
getLastMessageIdAsync([self, callback](Result result, const GetLastMessageIdResponse& response) {
Expand Down Expand Up @@ -1543,16 +1555,18 @@ void ConsumerImpl::hasMessageAvailableAsync(HasMessageAvailableCallback callback
}
});
} else {
if (hasMoreMessages(lastMessageIdInBroker_, messageId)) {
if (hasMoreMessages()) {
lock.unlock();
callback(ResultOk, true);
return;
}
lock.unlock();

getLastMessageIdAsync([callback, messageId](Result result, const GetLastMessageIdResponse& response) {
callback(result, (result == ResultOk) && hasMoreMessages(response.getLastMessageId(), messageId));
});
auto self = get_shared_this_ptr();
getLastMessageIdAsync()
.thenApply<bool>([this, self](const GetLastMessageIdResponse& response) {
std::lock_guard<std::mutex> lock{mutexForMessageId_};
return hasMoreMessages();
})
.addListener(std::move(callback));
}
}

Expand Down Expand Up @@ -1647,6 +1661,30 @@ bool ConsumerImpl::isConnected() const { return !getCnx().expired() && state_ ==

uint64_t ConsumerImpl::getNumberOfConnectedConsumer() { return isConnected() ? 1 : 0; }

static std::ostream& operator<<(std::ostream& os, const std::pair<MessageId, long>& seekInfo) {
if (seekInfo.second == 0L) {
os << seekInfo.first;
} else {
os << seekInfo.second;
}
return os;
}

static std::ostream& operator<<(std::ostream& os, const SeekStatus& status) {
switch (status) {
case SeekStatus::NOT_STARTED:
os << "not started";
break;
case SeekStatus::IN_PROGRESS:
os << "in progress";
break;
case SeekStatus::COMPLETED:
os << "completed";
break;
}
return os;
}

void ConsumerImpl::seekAsyncInternal(long requestId, SharedBuffer seek, const MessageId& seekId,
long timestamp, ResultCallback callback) {
ClientConnectionPtr cnx = getCnx().lock();
Expand All @@ -1656,49 +1694,63 @@ void ConsumerImpl::seekAsyncInternal(long requestId, SharedBuffer seek, const Me
return;
}

const auto originalSeekMessageId = seekMessageId_.get();
seekMessageId_ = seekId;
duringSeek_ = true;
if (timestamp > 0) {
LOG_INFO(getName() << " Seeking subscription to " << timestamp);
} else {
LOG_INFO(getName() << " Seeking subscription to " << seekId);
Lock seekLock{mutexForSeek_};
if (seekStatus_ != SeekStatus::NOT_STARTED) {
auto seekStatus = seekStatus_;
seekLock.unlock();
LOG_WARN(getName() << " attempted to seek " << std::make_pair(seekId, timestamp)
<< " when the status is " << seekStatus);
callback(ResultNotAllowedError);
return;
}
seekStatus_ = SeekStatus::IN_PROGRESS;
const auto originalSeekMessageId = seekMessageId_;
seekMessageId_ = seekId;
seekLock.unlock();

LOG_INFO(getName() << " Seeking subscription to " << std::make_pair(seekId, timestamp));

std::weak_ptr<ConsumerImpl> weakSelf{get_shared_this_ptr()};

cnx->sendRequestWithId(seek, requestId)
.addListener([this, weakSelf, callback, originalSeekMessageId](Result result,
const ResponseData& responseData) {
.addListener([this, weakSelf, callback, originalSeekMessageId, timestamp](
Result result, const ResponseData& responseData) {
auto self = weakSelf.lock();
if (!self) {
callback(result);
return;
}
if (result == ResultOk) {
LOG_INFO(getName() << "Seek successfully");
ackGroupingTrackerPtr_->flushAndClean();
incomingMessages_.clear();
Lock lock(mutexForMessageId_);
{
std::lock_guard<std::mutex> seekLock{mutexForSeek_};
seekCallback_ = std::move(callback);
seekStatus_ = SeekStatus::COMPLETED;
}

std::lock_guard<std::mutex> lock{mutexForMessageId_};
lastDequedMessageId_ = MessageId::earliest();
lock.unlock();
} else {
LOG_ERROR(getName() << "Failed to seek: " << result);
LOG_ERROR(getName() << "Failed to seek " << std::make_pair(originalSeekMessageId, timestamp)
<< ": " << result);
Lock seekLock{mutexForSeek_};
seekMessageId_ = originalSeekMessageId;
duringSeek_ = false;
seekStatus_ = SeekStatus::NOT_STARTED;
seekLock.unlock();
callback(result);
}
callback(result);
});
}

bool ConsumerImpl::isPriorBatchIndex(int32_t idx) {
return config_.isStartMessageIdInclusive() ? idx < startMessageId_.get().value().batchIndex()
: idx <= startMessageId_.get().value().batchIndex();
bool ConsumerImpl::isPriorBatchIndex(int32_t idx, const MessageId& startMessageId) {
return config_.isStartMessageIdInclusive() ? idx < startMessageId.batchIndex()
: idx <= startMessageId.batchIndex();
}

bool ConsumerImpl::isPriorEntryIndex(int64_t idx) {
return config_.isStartMessageIdInclusive() ? idx < startMessageId_.get().value().entryId()
: idx <= startMessageId_.get().value().entryId();
bool ConsumerImpl::isPriorEntryIndex(int64_t idx, const MessageId& startMessageId) {
return config_.isStartMessageIdInclusive() ? idx < startMessageId.entryId()
: idx <= startMessageId.entryId();
}

bool ConsumerImpl::hasEnoughMessagesForBatchReceive() const {
Expand Down
54 changes: 48 additions & 6 deletions lib/ConsumerImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ const static std::string SYSTEM_PROPERTY_REAL_TOPIC = "REAL_TOPIC";
const static std::string PROPERTY_ORIGIN_MESSAGE_ID = "ORIGIN_MESSAGE_ID";
const static std::string DLQ_GROUP_TOPIC_SUFFIX = "-DLQ";

enum class SeekStatus : std::uint8_t
{
NOT_STARTED,
IN_PROGRESS,
COMPLETED
};

class ConsumerImpl : public ConsumerImplBase {
public:
ConsumerImpl(const ClientImplPtr client, const std::string& topic, const std::string& subscriptionName,
Expand Down Expand Up @@ -174,8 +181,8 @@ class ConsumerImpl : public ConsumerImplBase {
void drainIncomingMessageQueue(size_t count);
uint32_t receiveIndividualMessagesFromBatch(const ClientConnectionPtr& cnx, Message& batchedMessage,
const BitSet& ackSet, int redeliveryCount);
bool isPriorBatchIndex(int32_t idx);
bool isPriorEntryIndex(int64_t idx);
bool isPriorBatchIndex(int32_t idx, const MessageId& startMessageId);
bool isPriorEntryIndex(int64_t idx, const MessageId& startMessageId);
void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, BrokerConsumerStatsCallback);

bool decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
Expand All @@ -193,6 +200,7 @@ class ConsumerImpl : public ConsumerImplBase {
const DeadlineTimerPtr& timer,
BrokerGetLastMessageIdCallback callback);

// This method must be called when `mutexForMessageId_` is held
boost::optional<MessageId> clearReceiveQueue();
void seekAsyncInternal(long requestId, SharedBuffer seek, const MessageId& seekId, long timestamp,
ResultCallback callback);
Expand Down Expand Up @@ -234,14 +242,16 @@ class ConsumerImpl : public ConsumerImplBase {
std::shared_ptr<Promise<Result, Producer>> deadLetterProducer_;
std::mutex createProducerLock_;

// Make the access to `lastDequedMessageId_` and `lastMessageIdInBroker_` thread safe
// Make the access to `startMessageId_`, `lastDequedMessageId_` and `lastMessageIdInBroker_` thread safe
mutable std::mutex mutexForMessageId_;
boost::optional<MessageId> startMessageId_;
MessageId lastDequedMessageId_{MessageId::earliest()};
MessageId lastMessageIdInBroker_{MessageId::earliest()};

std::atomic_bool duringSeek_{false};
Synchronized<boost::optional<MessageId>> startMessageId_;
Synchronized<MessageId> seekMessageId_{MessageId::earliest()};
mutable std::mutex mutexForSeek_;
SeekStatus seekStatus_{SeekStatus::NOT_STARTED};
MessageId seekMessageId_{MessageId::earliest()};
ResultCallback seekCallback_{nullptr};

class ChunkedMessageCtx {
public:
Expand Down Expand Up @@ -332,6 +342,38 @@ class ConsumerImpl : public ConsumerImplBase {
const proto::MessageIdData& messageIdData,
const ClientConnectionPtr& cnx, MessageId& messageId);

Future<Result, GetLastMessageIdResponse> getLastMessageIdAsync() {
Promise<Result, GetLastMessageIdResponse> promise;
getLastMessageIdAsync([promise](Result result, const GetLastMessageIdResponse& response) {
if (result == ResultOk) {
promise.setValue(response);
} else {
promise.setFailed(result);
}
});
return promise.getFuture();
}

// This method must be called when mutexForMessageId_ is held
bool hasMoreMessages() const {
if (lastMessageIdInBroker_.entryId() == -1L) {
// Need to get last message ID from broker
return false;
}
if (lastDequedMessageId_ == MessageId::earliest()) {
// No message is received, compare with the start message ID
auto startMessageId = startMessageId_.value_or(MessageId::latest());
// TODO: we need to wait until startMessageId becomes latest
if (config_.isStartMessageIdInclusive()) {
return lastMessageIdInBroker_ >= startMessageId;
} else {
return lastMessageIdInBroker_ > startMessageId;
}
} else {
return lastMessageIdInBroker_ > lastDequedMessageId_;
}
}

friend class PulsarFriend;
friend class MultiTopicsConsumerImpl;

Expand Down
Loading

0 comments on commit e6e0adc

Please sign in to comment.