From c342b980c9acdf232d04cc0f17dd4419db6fc94e Mon Sep 17 00:00:00 2001 From: Robert Roeser Date: Tue, 24 Sep 2024 00:59:53 -0700 Subject: [PATCH] 4/[thrift][checksum]port payload util free functions to a class Summary: Adds a default serializer and points the direct calls to PayloadUtils free function to the LegacyPayloadSerializerStrategy which delegates to the free functions. Reviewed By: praihan Differential Revision: D62485627 fbshipit-source-id: d52f9feb46c1cf5e75c7aa35838d6970515f3e25 --- .../DefaultPayloadSerializerStrategy.cpp | 212 +++++++++++++++++- .../DefaultPayloadSerializerStrategy.h | 103 ++++++++- .../payload/LegacyPayloadSerializerStrategy.h | 14 +- .../rocket/payload/PayloadSerializer.cpp | 68 ++++-- .../rocket/payload/PayloadSerializer.h | 70 +++--- .../payload/PayloadSerializerStrategy.h | 11 +- .../payload/test/PayloadSerializerTest.cpp | 2 + 7 files changed, 403 insertions(+), 77 deletions(-) diff --git a/thrift/lib/cpp2/transport/rocket/payload/DefaultPayloadSerializerStrategy.cpp b/thrift/lib/cpp2/transport/rocket/payload/DefaultPayloadSerializerStrategy.cpp index b65b1450711..f04497cd3b7 100644 --- a/thrift/lib/cpp2/transport/rocket/payload/DefaultPayloadSerializerStrategy.cpp +++ b/thrift/lib/cpp2/transport/rocket/payload/DefaultPayloadSerializerStrategy.cpp @@ -16,24 +16,212 @@ #include +#include +#include +#include +#include +#include + +#include +#include + namespace apache::thrift::rocket { -template -folly::Try DefaultPayloadSerializerStrategy::unpackAsCompressed( - rocket::Payload&&, bool) { - throw std::runtime_error("not implemented"); +namespace { + +template +void applyCompressionIfNeeded( + std::unique_ptr& payload, Metadata* metadata) { + if (auto compress = metadata->compression_ref()) { + apache::thrift::rocket::detail::compressPayload(payload, *compress); + } +} + +bool validateFileDescriptor(size_t numFds, FdMetadata& fdMetadata) { + // The kernel maximum is actually much lower (at least on Linux, and + // MacOS doesn't seem to document it at all), but that will only fail in + // in `AsyncFdSocket`. + constexpr auto numFdsTypeMax = std::numeric_limits< + op::get_native_type>::max(); + + if (LIKELY(numFdsTypeMax <= numFds)) { + return true; + } else { + LOG(DFATAL) << numFds << " would overflow FdMetadata::numFds"; + fdMetadata.numFds() = numFdsTypeMax; + // This will cause "AsyncFdSocket::writeChainWithFds" to error out. + fdMetadata.fdSeqNum() = folly::SocketFds::kNoSeqNum; + return false; + } +} + +template +void handleFds( + folly::SocketFds& fds, + Metadata* metadata, + folly::AsyncTransport* transport) { + auto numFds = fds.size(); + if (numFds) { + FdMetadata fdMetadata; + if (LIKELY(validateFileDescriptor(numFds, fdMetadata))) { + // When received, the request will know to retrieve this many FDs. + fdMetadata.numFds() = numFds; + // FD sequence numbers count the total number of FDs sent on this + // socket, and are used to detect & fail on the dire class of bugs where + // the wrong FDs are about to be associated with a message. + // + // We currently require message bytes and FDs to be both sent and + // received in a coherent order, so sequence numbers here in `pack*` are + // expected to exactly match the sequencing of socket sends, and also the + // sequencing of `popNextReceivedFds` on the receiving side. + // + // NB: If `transport` is not backed by a `AsyncFdSocket*`, this will + // store `fdSeqNum == -1`, which cannot happen otherwise, thanks to + // AsyncFdSocket's 2^63 -> 0 wrap-around logic. Furthermore, the + // subsequent `writeChainWithFds` will discard `fds`. As a result, the + // recipient will see read errors on the FDs due to both `numFds` not + // matching, and `fdSeqNum` not matching. + fdMetadata.fdSeqNum() = + injectFdSocketSeqNumIntoFdsToSend(transport, &fds); + } + + DCHECK(!metadata->fdMetadata().has_value()); + metadata->fdMetadata() = fdMetadata; + } } -template -folly::Try DefaultPayloadSerializerStrategy::unpack( - rocket::Payload&&, bool) { - throw std::runtime_error("not implemented"); +} // namespace + +template +rocket::Payload DefaultPayloadSerializerStrategy::finalizePayload( + std::unique_ptr&& payload, + Metadata* metadata, + folly::SocketFds fds) { + auto ret = makePayload( + *metadata, std::move(payload)); + if (fds.size()) { + ret.fds = std::move(fds.dcheckToSendOrEmpty()); + } + return ret; } -template -Payload DefaultPayloadSerializerStrategy::pack( - PayloadType&&, folly::AsyncTransport*) { - throw std::runtime_error("not implemented"); +template +rocket::Payload DefaultPayloadSerializerStrategy::packWithFds( + Metadata* metadata, + std::unique_ptr&& payload, + folly::SocketFds fds, + folly::AsyncTransport* transport) { + applyCompressionIfNeeded(payload, metadata); + handleFds(fds, metadata, transport); + return finalizePayload(std::move(payload), metadata, std::move(fds)); } +template rocket::Payload +DefaultPayloadSerializerStrategy::packWithFds( + RequestRpcMetadata*, + std::unique_ptr&&, + folly::SocketFds, + folly::AsyncTransport*); + +template rocket::Payload +DefaultPayloadSerializerStrategy::packWithFds( + ResponseRpcMetadata*, + std::unique_ptr&&, + folly::SocketFds, + folly::AsyncTransport*); + +template rocket::Payload +DefaultPayloadSerializerStrategy::packWithFds( + StreamPayloadMetadata*, + std::unique_ptr&&, + folly::SocketFds, + folly::AsyncTransport*); + +bool DefaultPayloadSerializerStrategy:: + canSerializeMetadataIntoDataBufferHeadroom( + const std::unique_ptr& data, const size_t serSize) const { + return data && !data->isChained() && + data->headroom() >= serSize + kHeadroomBytes && !data->isSharedOne(); +} + +template +Payload DefaultPayloadSerializerStrategy::makePayloadWithHeadroom( + ProtocolWriter& writer, + const Metadata& metadata, + std::unique_ptr data) { + folly::IOBufQueue queue; + // Store previous state of the buffer pointers and rewind it. + auto startBuffer = data->buffer(); + auto start = data->data(); + auto origLen = data->length(); + data->trimEnd(origLen); + data->retreat(start - startBuffer); + + queue.append(std::move(data), false); + writer.setOutput(&queue); + auto metadataLen = metadata.write(&writer); + + // Move the new data to come right before the old data and restore the + // old tail pointer. + data = queue.move(); + data->advance(start - data->tail()); + data->append(origLen); + + return Payload::makeCombined(std::move(data), metadataLen); +} + +template +Payload DefaultPayloadSerializerStrategy::makePayloadWithoutHeadroom( + size_t serSize, + ProtocolWriter& writer, + const Metadata& metadata, + std::unique_ptr data) { + folly::IOBufQueue queue; + constexpr size_t kMinAllocBytes = 1024; + auto buf = + folly::IOBuf::create(std::max(kHeadroomBytes + serSize, kMinAllocBytes)); + buf->advance(kHeadroomBytes); + queue.append(std::move(buf)); + writer.setOutput(&queue); + auto metadataLen = metadata.write(&writer); + queue.append(std::move(data)); + return Payload::makeCombined(queue.move(), metadataLen); +} + +template +Payload DefaultPayloadSerializerStrategy::makePayload( + const Metadata& metadata, std::unique_ptr data) { + ProtocolWriter writer; + // Default is to leave some headroom for rsocket headers + size_t serSize = metadata.serializedSizeZC(&writer); + + // If possible, serialize metadata into the headeroom of data. + if (canSerializeMetadataIntoDataBufferHeadroom(data, serSize)) { + return makePayloadWithHeadroom(writer, metadata, std::move(data)); + } else { + return makePayloadWithoutHeadroom( + serSize, writer, metadata, std::move(data)); + } +} + +template Payload DefaultPayloadSerializerStrategy:: + makePayload( + const RequestRpcMetadata&, std::unique_ptr data); +template Payload DefaultPayloadSerializerStrategy:: + makePayload( + const ResponseRpcMetadata&, std::unique_ptr data); +template Payload DefaultPayloadSerializerStrategy:: + makePayload( + const StreamPayloadMetadata&, std::unique_ptr data); + +template Payload DefaultPayloadSerializerStrategy:: + makePayload( + const RequestRpcMetadata&, std::unique_ptr data); +template Payload DefaultPayloadSerializerStrategy:: + makePayload( + const ResponseRpcMetadata&, std::unique_ptr data); +template Payload DefaultPayloadSerializerStrategy:: + makePayload( + const StreamPayloadMetadata&, std::unique_ptr data); + } // namespace apache::thrift::rocket diff --git a/thrift/lib/cpp2/transport/rocket/payload/DefaultPayloadSerializerStrategy.h b/thrift/lib/cpp2/transport/rocket/payload/DefaultPayloadSerializerStrategy.h index fca541763d5..ab3fcaeb720 100644 --- a/thrift/lib/cpp2/transport/rocket/payload/DefaultPayloadSerializerStrategy.h +++ b/thrift/lib/cpp2/transport/rocket/payload/DefaultPayloadSerializerStrategy.h @@ -21,19 +21,30 @@ namespace apache::thrift::rocket { -// TODO rroeser - right now this is a no-op, but will be used for adding support -// for checksum and compression, and other features that don't delegate to the -// PayloadUtils.h header free functions. +/** + * Port of PayloadUtils.h header free functions into a strategy class. + */ class DefaultPayloadSerializerStrategy final : public PayloadSerializerStrategy { public: DefaultPayloadSerializerStrategy() : PayloadSerializerStrategy(*this) {} template - folly::Try unpackAsCompressed(rocket::Payload&& payload, bool useBinary); + folly::Try unpackAsCompressed(rocket::Payload&& payload) { + return folly::makeTryWith([&]() { + T t = unpackImpl(std::move(payload)); + if (auto compression = t.metadata.compression()) { + t.payload = uncompressBuffer(std::move(t.payload), *compression); + } + return std::move(t); + }); + } template - folly::Try unpack(rocket::Payload&& payload, bool useBinary); + folly::Try unpack(rocket::Payload&& payload) { + return folly::makeTryWith( + [&]() { return unpackImpl(std::move(payload)); }); + } template rocket::Payload packWithFds( @@ -43,12 +54,88 @@ class DefaultPayloadSerializerStrategy final folly::AsyncTransport* transport); template - size_t unpackCompact(T&, const folly::IOBuf*) { - throw std::runtime_error("not implemented"); + std::unique_ptr packCompact(T&& data) { + CompactProtocolWriter writer; + folly::IOBufQueue queue; + writer.setOutput(&queue); + data.write(&writer); + return queue.move(); + } + + template + size_t unpackCompact(T& output, const folly::IOBuf* buffer) { + if (FOLLY_UNLIKELY(!buffer)) { + folly::throw_exception("Underflow"); + } + CompactProtocolReader reader; + reader.setInput(buffer); + output.read(&reader); + return reader.getCursorPosition(); } template - Payload pack(PayloadType&& payload, folly::AsyncTransport* transport); + rocket::Payload pack( + PayloadType&& payload, folly::AsyncTransport* transport) { + auto metadata = std::forward(payload).metadata; + return packWithFds( + &metadata, + std::forward(payload).payload, + std::forward(payload).fds, + transport); + } + + private: + static constexpr size_t kHeadroomBytes = 16; + + template + rocket::Payload finalizePayload( + std::unique_ptr&& payload, + Metadata* metadata, + folly::SocketFds fds); + + bool canSerializeMetadataIntoDataBufferHeadroom( + const std::unique_ptr& data, const size_t serSize) const; + + template + Payload makePayloadWithHeadroom( + ProtocolWriter& writer, + const Metadata& metadata, + std::unique_ptr data); + + template + Payload makePayloadWithoutHeadroom( + size_t serSize, + ProtocolWriter& writer, + const Metadata& metadata, + std::unique_ptr data); + + template + Payload makePayload( + const Metadata& metadata, std::unique_ptr data); + + void verifyMetadataSize(size_t metadataSize, size_t expectedSize) { + if (metadataSize != expectedSize) { + folly::throw_exception("metadata size mismatch"); + } + } + + template + T unpackImpl(rocket::Payload&& payload) { + T t{{}, {}}; + unpackPayloadMetadata(t, payload); + t.payload = std::move(payload).data(); + return t; + } + + template + void unpackPayloadMetadata(T& t, rocket::Payload& payload) { + if (payload.hasNonemptyMetadata()) { + if (unpackCompact(t.metadata, payload.buffer()) != + payload.metadataSize()) { + folly::throw_exception("metadata size mismatch"); + } + } + } }; } // namespace apache::thrift::rocket diff --git a/thrift/lib/cpp2/transport/rocket/payload/LegacyPayloadSerializerStrategy.h b/thrift/lib/cpp2/transport/rocket/payload/LegacyPayloadSerializerStrategy.h index d15f128ab33..c378ace888b 100644 --- a/thrift/lib/cpp2/transport/rocket/payload/LegacyPayloadSerializerStrategy.h +++ b/thrift/lib/cpp2/transport/rocket/payload/LegacyPayloadSerializerStrategy.h @@ -37,8 +37,8 @@ class LegacyPayloadSerializerStrategy final } template - FOLLY_ERASE folly::Try unpack(rocket::Payload&& payload, bool useBinary) { - return ::apache::thrift::rocket::unpack(std::move(payload), useBinary); + FOLLY_ERASE folly::Try unpack(rocket::Payload&& payload) { + return ::apache::thrift::rocket::unpack(std::move(payload)); } template @@ -57,6 +57,16 @@ class LegacyPayloadSerializerStrategy final return ::apache::thrift::rocket::pack( std::forward(payload), transport); } + + template + FOLLY_ERASE rocket::Payload packWithFds( + Metadata* metadata, + std::unique_ptr&& payload, + folly::SocketFds fds, + folly::AsyncTransport* transport) { + return ::apache::thrift::rocket::packWithFds( + metadata, std::move(payload), std::move(fds), transport); + } }; } // namespace apache::thrift::rocket diff --git a/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializer.cpp b/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializer.cpp index f2f2f7d40ee..ec880e0a280 100644 --- a/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializer.cpp +++ b/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializer.cpp @@ -16,41 +16,67 @@ #include -#include - namespace apache::thrift::rocket { -namespace { -auto initializationLock = folly::Singleton().shouldEagerInit(); -static folly::Indestructible> serializer; -} // namespace +PayloadSerializer::PayloadSerializerHolder::~PayloadSerializerHolder() { + PayloadSerializer* serializer = serializer_; + if (serializer) { + delete serializer; + } +} + +PayloadSerializer& PayloadSerializer::PayloadSerializerHolder::get() { + auto* serializer = serializer_.load(std::memory_order_relaxed); -bool isSerializerInitialized() { - return serializer->has_value(); + // Fast path when the serializer is already initialized + if (FOLLY_LIKELY(serializer != nullptr)) { + return *serializer; + } else { + // Slow path when the serializer is not initialized yet that + // uses a compare-and-swap to initialize it. Avoids the need for a lock. + auto* newSerializer = + new PayloadSerializer(LegacyPayloadSerializerStrategy()); + for (;;) { + // Load the current serializer + auto* expected = serializer_.load(std::memory_order_relaxed); + + // Check if the serializer is already initialized + if (expected == nullptr) { + // Try to initialize the serializer + if (serializer_.compare_exchange_strong( + expected, newSerializer, std::memory_order_release)) { + return *newSerializer; + } + } else { + // The serializer is already initialized, return it and clean up. + delete newSerializer; + return *expected; + } + } + } } -void PayloadSerializer::tryInitialize(PayloadSerializer&& src) { - std::lock_guard lock(*initializationLock.try_get()); - if (!isSerializerInitialized()) { - serializer->emplace(std::move(src)); +void PayloadSerializer::PayloadSerializerHolder::reset() { + auto* serializer = serializer_.exchange(nullptr); + if (serializer) { + delete serializer; } } -void PayloadSerializer::tryInitializeDefault() { - tryInitializeEmplace(PayloadSerializer(LegacyPayloadSerializerStrategy())); +PayloadSerializer::PayloadSerializerHolder& +PayloadSerializer::getPayloadSerializerHolder() { + static folly::Indestructible + holder; + + return *holder; } PayloadSerializer& PayloadSerializer::getInstance() { - if (!FOLLY_UNLIKELY(isSerializerInitialized())) { - tryInitializeDefault(); - } - std::optional& opt = *serializer; - return *opt; + return getPayloadSerializerHolder().get(); } void PayloadSerializer::reset() { - std::lock_guard lock(*initializationLock.try_get()); - serializer->reset(); + getPayloadSerializerHolder().reset(); } } // namespace apache::thrift::rocket diff --git a/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializer.h b/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializer.h index a83e903f3f8..39ab38a4415 100644 --- a/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializer.h +++ b/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializer.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -40,20 +41,43 @@ class PayloadSerializer { variant strategy_; - template - explicit PayloadSerializer(Strategy s) : strategy_(std::move(s)) {} + struct PayloadSerializerHolder { + PayloadSerializerHolder() {} + ~PayloadSerializerHolder(); + + PayloadSerializer& get(); + + template + void initialize(Strategy&& strategy) { + auto* serializer = + new PayloadSerializer(std::forward(strategy)); + auto* other = serializer_.exchange(serializer); + if (other) { + delete other; + } + } + + void reset(); + + private: + std::atomic serializer_{nullptr}; + }; public: + template < + typename Strategy, + typename = std::enable_if_t< + std::is_base_of_v, Strategy>>> + explicit PayloadSerializer(Strategy s) : strategy_(std::move(s)) {} /** * Lets you override the strategy to one of the supported strategies instead * of the default. Must be called before the getInstance() method is called * for the first time to take effect. Otherwise, it will be ignored. */ template - static void initialize(Strategy strategy) { - tryInitialize(PayloadSerializer(std::move(strategy))); + static void initialize(Strategy&& strategy) { + getPayloadSerializerHolder().initialize(std::forward(strategy)); } - /** * Returns the singleton instance of the PayloadSerializer. Either returns * the default strategy or the one that was overridden by the initialize() @@ -62,27 +86,23 @@ class PayloadSerializer { static PayloadSerializer& getInstance(); template - folly::Try unpackAsCompressed(Payload&& payload, bool useBinary) { + FOLLY_ERASE folly::Try unpackAsCompressed(Payload&& payload) { return folly::variant_match( - strategy_, - [payload = std::move(payload), useBinary](auto& strategy) mutable { - return strategy.template unpackAsCompressed( - std::move(payload), useBinary); + strategy_, [payload = std::move(payload)](auto& strategy) mutable { + return strategy.template unpackAsCompressed(std::move(payload)); }); } template - folly::Try unpack(rocket::Payload&& payload, bool useBinary) { + FOLLY_ERASE folly::Try unpack(rocket::Payload&& payload) { return folly::variant_match( - strategy_, - [payload = std::move(payload), - useBinary = useBinary](auto& strategy) mutable { - return strategy.template pack(std::move(payload), useBinary); + strategy_, [payload = std::move(payload)](auto& strategy) mutable { + return strategy.template unpack(std::move(payload)); }); } template - std::unique_ptr packCompact(T&& data) { + FOLLY_ERASE std::unique_ptr packCompact(T&& data) { return folly::variant_match( strategy_, [data = std::forward(data)](auto& strategy) mutable { return strategy.packCompact(std::forward(data)); @@ -90,7 +110,7 @@ class PayloadSerializer { } template - size_t unpackCompact(T& output, const folly::IOBuf* buffer) { + FOLLY_ERASE size_t unpackCompact(T& output, const folly::IOBuf* buffer) { return folly::variant_match( strategy_, [&output, buffer](auto& strategy) mutable { return strategy.unpackCompact(output, buffer); @@ -98,7 +118,7 @@ class PayloadSerializer { } template - rocket::Payload packWithFds( + FOLLY_ERASE rocket::Payload packWithFds( Metadata* metadata, std::unique_ptr&& payload, folly::SocketFds fds, @@ -115,13 +135,14 @@ class PayloadSerializer { } template - Payload pack(PayloadType&& payload, folly::AsyncTransport* transport) { + FOLLY_ERASE Payload + pack(PayloadType&& payload, folly::AsyncTransport* transport) { return folly::variant_match( strategy_, [payload = std::forward(payload), transport](auto& strategy) mutable { return strategy.template pack( - std::forward(payload)); + std::forward(payload), transport); }); } @@ -132,14 +153,7 @@ class PayloadSerializer { */ static void reset(); - static void tryInitialize(PayloadSerializer&& src); - - template - static void tryInitializeEmplace(Args&&... args) { - tryInitialize(std::forward(args)...); - } - - static void tryInitializeDefault(); + static PayloadSerializerHolder& getPayloadSerializerHolder(); FRIEND_TEST(PayloadSerializerTest, TestPackWithLegacyStrategy); FRIEND_TEST(PayloadSerializerTest, TestPackWitDefaultyStrategy); diff --git a/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializerStrategy.h b/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializerStrategy.h index 0d3610ab817..2f06c240e0f 100644 --- a/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializerStrategy.h +++ b/thrift/lib/cpp2/transport/rocket/payload/PayloadSerializerStrategy.h @@ -29,14 +29,13 @@ template class PayloadSerializerStrategy { public: template - FOLLY_ERASE folly::Try unpackAsCompressed( - Payload&& payload, bool useBinary) { - return child_.unpackAsCompressed(std::move(payload), useBinary); + FOLLY_ERASE folly::Try unpackAsCompressed(Payload&& payload) { + return child_.unpackAsCompressed(std::move(payload)); } template - FOLLY_ERASE folly::Try unpack(Payload&& payload, bool useBinary) { - return child_.unpack(std::move(payload), useBinary); + FOLLY_ERASE folly::Try unpack(Payload&& payload) { + return child_.unpack(std::move(payload)); } template @@ -61,7 +60,7 @@ class PayloadSerializerStrategy { } template - size_t unpackCompact(T& output, const folly::IOBuf* buffer) { + FOLLY_ERASE size_t unpackCompact(T& output, const folly::IOBuf* buffer) { return child_.unpackCompact(output, buffer); } diff --git a/thrift/lib/cpp2/transport/rocket/payload/test/PayloadSerializerTest.cpp b/thrift/lib/cpp2/transport/rocket/payload/test/PayloadSerializerTest.cpp index 67822b22f1f..0cdcfa58faf 100644 --- a/thrift/lib/cpp2/transport/rocket/payload/test/PayloadSerializerTest.cpp +++ b/thrift/lib/cpp2/transport/rocket/payload/test/PayloadSerializerTest.cpp @@ -31,6 +31,7 @@ void testPackAndUnpackWithCompactProtocol(PayloadSerializer& serializer) { RequestRpcMetadata other; serializer.unpackCompact(other, payload.get()); EXPECT_EQ(other, metadata); + EXPECT_EQ(other.protocol(), ProtocolId::COMPACT); } TEST(PayloadSerializerTest, TestPackWithLegacyStrategy) { @@ -42,6 +43,7 @@ TEST(PayloadSerializerTest, TestPackWithLegacyStrategy) { TEST(PayloadSerializerTest, TestPackWitDefaultyStrategy) { PayloadSerializer::reset(); + PayloadSerializer::initialize(DefaultPayloadSerializerStrategy()); auto& serializer = PayloadSerializer::getInstance(); testPackAndUnpackWithCompactProtocol(serializer); }