Skip to content
This repository has been archived by the owner on Jan 7, 2022. It is now read-only.

Commit

Permalink
Control where Thrift/Future client callbacks executed
Browse files Browse the repository at this point in the history
Summary:
There are multiple ways of calling Thrift RPC on the client side with respect to how result of the call is delivered:
1. sync: blocking call, client thread is waiting.
2. async: asynchronous, result delivered through callback
3. Future-based: caller gets Future that is fulfilled when response is received
4. SemiFuture: the same as above but with SemiFuture instead of Future
5. coroutine (if enabled): caller gets coro task which is completed with RPC result

This abundance of APIs gives us a problem because on our side we need to integrate them somehow with our Worker-based threading model.
### How we will do it
#1 is trivial (if you call it from Worker you will get result in the same Worker thread but nonetheless you should never use it in production because it consumes the whole Worker).
#4 and #5 trivial as well because caller has to provide execution context for getting result anyway and can simply use Worker.
But #2 is problematic because by default it will deliver response in Thrift's own IO thread even if client method has been called from Worker (as well as #3 but this one should never be used in practice because Futures are deprecated).

### How to solve async methods?
Async methods require callback to passed. We will provide option to set callback executor when creating a new Thrift client and when async_ method is called we will wrap client-provided callback into one which simply delegates execution to this callback executor. As a nice side benefit It will work with Future methods as well.

Reviewed By: MohamedBassem

Differential Revision: D22866545

fbshipit-source-id: 781b36ba38506cd9a1c78162cd1127afab7ed39e
  • Loading branch information
747mmHg authored and facebook-github-bot committed Aug 27, 2020
1 parent a0879b2 commit a8fe5d2
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 40 deletions.
110 changes: 110 additions & 0 deletions logdevice/common/thrift/RocketChannelWrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/**
* Copyright (c) 2017-present, Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "logdevice/common/thrift/RocketChannelWrapper.h"

using apache::thrift::ClientReceiveState;
using apache::thrift::RequestClientCallback;
using CallbackPtr = RequestClientCallback::Ptr;

namespace facebook { namespace logdevice { namespace detail {
namespace {
template <bool oneWay>
/**
* Ensures client provided callback is called from provided executor unless they
* are safe to call inline. User-defined callbacks are never safe by default but
* most of Thrift callbacks are (e.g. callback adapting async call to
* Future/SemiFuture/coroutines API).
*/
class ExecutorRequestCallback final : public RequestClientCallback {
public:
ExecutorRequestCallback(CallbackPtr cb,
folly::Executor::KeepAlive<> executor_keep_alive)
: executor_keep_alive_(std::move(executor_keep_alive)),
cb_(std::move(cb)) {
ld_check(executor_keep_alive_);
}

void onRequestSent() noexcept override {
if (oneWay) {
executor_keep_alive_->add(
[cb = std::move(cb_)]() mutable { cb.release()->onRequestSent(); });
delete this;
} else {
requestSent_ = true;
}
}
void onResponse(ClientReceiveState&& rs) noexcept override {
executor_keep_alive_->add([requestSent = requestSent_,
cb = std::move(cb_),
rs = std::move(rs)]() mutable {
if (requestSent) {
cb->onRequestSent();
}
cb.release()->onResponse(std::move(rs));
});
delete this;
}
void onResponseError(folly::exception_wrapper ex) noexcept override {
executor_keep_alive_->add([requestSent = requestSent_,
cb = std::move(cb_),
ex = std::move(ex)]() mutable {
if (requestSent) {
cb->onRequestSent();
}
cb.release()->onResponseError(std::move(ex));
});
delete this;
}

private:
bool requestSent_{false};
folly::Executor::KeepAlive<> executor_keep_alive_;
CallbackPtr cb_;
};

} // namespace

template <bool oneWayCb>
CallbackPtr RocketChannelWrapper::wrapIfUnsafe(CallbackPtr cob) {
if (!cob->isInlineSafe()) {
return CallbackPtr(new ExecutorRequestCallback<oneWayCb>(
std::move(cob), getKeepAliveToken(callback_executor_)));
} else {
return cob;
}
}

void RocketChannelWrapper::sendRequestResponse(
const apache::thrift::RpcOptions& options,
folly::StringPiece method_name,
apache::thrift::SerializedRequest&& request,
std::shared_ptr<apache::thrift::transport::THeader> header,
CallbackPtr cob) {
cob = wrapIfUnsafe<false>(std::move(cob));
channel_->sendRequestResponse(options,
method_name,
std::move(request),
std::move(header),
std::move(cob));
}

void RocketChannelWrapper::sendRequestNoResponse(
const apache::thrift::RpcOptions& options,
folly::StringPiece method_name,
apache::thrift::SerializedRequest&& request,
std::shared_ptr<apache::thrift::transport::THeader> header,
CallbackPtr cob) {
cob = wrapIfUnsafe<true>(std::move(cob));
channel_->sendRequestNoResponse(options,
method_name,
std::move(request),
std::move(header),
std::move(cob));
}
}}} // namespace facebook::logdevice::detail
45 changes: 26 additions & 19 deletions logdevice/common/thrift/RocketChannelWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace facebook { namespace logdevice { namespace detail {

/**
* Wraps RocketClientChannel to allow calling some of its methods (such as
* setTimeout and d-tor) outside of EventBase loop.
* setTimeout and d-tor) outside of EventBase loop. Also ensures Thrift
* callbacks are called from specified executor.
*
* This object is not thread-safe, concurrent usages of the same object from
* different threads will lead to undefined behaviour.
Expand All @@ -32,37 +33,35 @@ class RocketChannelWrapper : public apache::thrift::RequestChannel {
* given event base. The created wrapper takes ownership of channel but tt is
* caller's responsibilty to ensure the passed event base out-lives the
* wrapper.
* @param channel Underlying Rocket transport
* @param evb Event base which will be used for all operation on
* underlying channel
* @param callback_executor Thrift callbacks (for async methods) and Future
* callbacks (for future_ methods) will run on this
* executor. If null then IO thread will run
* callbacks.
*/
static Ptr newChannel(apache::thrift::RocketClientChannel::Ptr channel,
folly::EventBase* evb) {
return {new RocketChannelWrapper(std::move(channel), evb), {}};
folly::EventBase* evb,
folly::Executor* callback_executor) {
return {
new RocketChannelWrapper(std::move(channel), evb, callback_executor),
{}};
}

void sendRequestResponse(
const apache::thrift::RpcOptions& options,
folly::StringPiece method_name,
apache::thrift::SerializedRequest&& request,
std::shared_ptr<apache::thrift::transport::THeader> header,
apache::thrift::RequestClientCallback::Ptr cob) override {
channel_->sendRequestResponse(options,
method_name,
std::move(request),
std::move(header),
std::move(cob));
}
apache::thrift::RequestClientCallback::Ptr cob) override;

void sendRequestNoResponse(
const apache::thrift::RpcOptions& options,
folly::StringPiece method_name,
apache::thrift::SerializedRequest&& request,
std::shared_ptr<apache::thrift::transport::THeader> header,
apache::thrift::RequestClientCallback::Ptr cob) override {
channel_->sendRequestNoResponse(options,
method_name,
std::move(request),
std::move(header),
std::move(cob));
}
apache::thrift::RequestClientCallback::Ptr cob) override;

void
sendRequestStream(const apache::thrift::RpcOptions& options,
Expand Down Expand Up @@ -102,13 +101,21 @@ class RocketChannelWrapper : public apache::thrift::RequestChannel {
}

RocketChannelWrapper(apache::thrift::RocketClientChannel::Ptr channel,
folly::EventBase* evb)
: channel_(std::move(channel)), evb_(evb) {
folly::EventBase* evb,
folly::Executor* callback_executor)
: channel_(std::move(channel)),
evb_(evb),
callback_executor_(callback_executor) {
ld_check(channel_);
ld_check(evb_);
}

template <bool oneWayCb>
apache::thrift::RequestClientCallback::Ptr
wrapIfUnsafe(apache::thrift::RequestClientCallback::Ptr cob);

apache::thrift::RocketClientChannel::Ptr channel_;
folly::EventBase* evb_;
folly::Executor* callback_executor_;
};
}}} // namespace facebook::logdevice::detail
23 changes: 13 additions & 10 deletions logdevice/common/thrift/SimpleThriftClientFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,22 @@ class SimpleThriftClientFactory : public ThriftClientFactory {

protected:
ThriftClientFactory::ChannelPtr
createChannel(const folly::SocketAddress& address) override {
createChannel(const folly::SocketAddress& address,
folly::Executor* callback_executor) override {
// Get random evb for this client
auto evb = io_executor_.getEventBase();
ThriftClientFactory::ChannelPtr channel;
evb->runInEventBaseThreadAndWait([address, evb, &channel, this]() {
AsyncSocket::UniquePtr socket(
new AsyncSocket(evb, address, connect_timeout_.count()));
auto rocket = RocketClientChannel::newChannel(std::move(socket));
if (request_timeout_.count() > 0) {
rocket->setTimeout(request_timeout_.count());
}
channel = RocketChannelWrapper::newChannel(std::move(rocket), evb);
});
evb->runInEventBaseThreadAndWait(
[address, evb, &channel, this, callback_executor]() {
AsyncSocket::UniquePtr socket(
new AsyncSocket(evb, address, connect_timeout_.count()));
auto rocket = RocketClientChannel::newChannel(std::move(socket));
if (request_timeout_.count() > 0) {
rocket->setTimeout(request_timeout_.count());
}
channel = RocketChannelWrapper::newChannel(
std::move(rocket), evb, callback_executor);
});
return channel;
}

Expand Down
10 changes: 7 additions & 3 deletions logdevice/common/thrift/ThriftClientFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <memory>

#include <folly/Executor.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/DelayedDestruction.h>

Expand All @@ -35,8 +36,10 @@ class ThriftClientFactory {
* @return Pointer to new client.
*/
template <typename T>
std::unique_ptr<T> createClient(const folly::SocketAddress& address) {
ChannelPtr channel = createChannel(address);
std::unique_ptr<T>
createClient(const folly::SocketAddress& address,
folly::Executor* callback_executor = nullptr) {
ChannelPtr channel = createChannel(address, callback_executor);
return std::make_unique<T>(std::move(channel));
}

Expand All @@ -51,7 +54,8 @@ class ThriftClientFactory {
*
* @param address Address of the Thrift server to connect to.
*/
virtual ChannelPtr createChannel(const folly::SocketAddress& address) = 0;
virtual ChannelPtr createChannel(const folly::SocketAddress& address,
folly::Executor* callback_executor) = 0;
};

}} // namespace facebook::logdevice
4 changes: 3 additions & 1 deletion logdevice/common/thrift/ThriftRouter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "logdevice/common/thrift/ThriftRouter.h"

#include "logdevice/common/Worker.h"
#include "logdevice/common/configuration/nodes/ServerAddressRouter.h"
#include "logdevice/common/debug.h"
#include "logdevice/common/if/gen-cpp2/LogDeviceAPIAsyncClient.h"
Expand Down Expand Up @@ -44,7 +45,8 @@ NcmThriftRouter::getApiClient(node_index_t nid) {
nid);
return nullptr;
}
auto callback_executor = Worker::onThisThread(/*enforce_worker*/ false);
return client_factory_->createClient<LogDeviceAPIAsyncClient>(
maybe_address->getSocketAddress());
maybe_address->getSocketAddress(), callback_executor);
}
}} // namespace facebook::logdevice
6 changes: 5 additions & 1 deletion logdevice/common/thrift/ThriftRouter.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <memory>

#include <folly/Executor.h>

#include "logdevice/common/NodeID.h"
#include "logdevice/common/configuration/Configuration.h"
#include "logdevice/common/settings/Settings.h"
Expand All @@ -35,10 +37,12 @@ class ThriftRouter {
/**
* Creates a new client for Thrift API on the node with given ID.
*
* @param nid ID of the node which will be used as a destination
* for all Thrift requests on the client
* @return New Thrift client or nullptr if unable to router.
*/
virtual std::unique_ptr<thrift::LogDeviceAPIAsyncClient>
getApiClient(node_index_t) = 0;
getApiClient(node_index_t nid) = 0;

virtual ~ThriftRouter() = default;
};
Expand Down
65 changes: 59 additions & 6 deletions logdevice/test/ThriftApiIntegrationTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "logdevice/test/utils/IntegrationTestUtils.h"

using namespace ::testing;
using apache::thrift::ClientReceiveState;
using facebook::fb303::cpp2::fb_status;
using facebook::logdevice::IntegrationTestUtils::ClusterFactory;

Expand All @@ -40,27 +41,79 @@ class ThriftApiIntegrationTestBase : public IntegrationTestBase {
}

protected:
using ThriftClient = thrift::LogDeviceAPIAsyncClient;
using ThriftClientPtr = std::unique_ptr<thrift::LogDeviceAPIAsyncClient>;
// This test checks that Thirft server starts and we are able to make an RPC
// request to it from within Worker
void checkSingleRpcCall() {
auto status = runWithClient([](thrift::LogDeviceAPIAsyncClient& client) {
return client.semifuture_getStatus().get();
auto status = runWithClient([](ThriftClientPtr client) {
auto cb = std::make_unique<StatusCb>(std::move(client));
auto client_ptr = cb->client_ptr.get();
auto future = cb->responsePromise.getSemiFuture();
client_ptr->getStatus(std::move(cb));
return future;
});
ASSERT_EQ(fb_status::ALIVE, status);
ASSERT_EQ(fb_status::ALIVE, std::move(status).get());
}

class StatusCb : public apache::thrift::SendRecvRequestCallback {
public:
StatusCb(ThriftClientPtr client)
: client_ptr(std::move(client)),
original_worker(Worker::onThisThread()){};

bool ensureWorker() {
if (Worker::onThisThread(false) != original_worker) {
responsePromise.setException(
std::runtime_error("Callback is called either not in Worker "
"context or by wrong worker"));
return false;
}
return true;
}

void send(folly::exception_wrapper&& ex) override {
if (!ensureWorker()) {
return;
}
if (ex) {
responsePromise.setException(std::move(ex));
}
}

void recv(ClientReceiveState&& state) override {
if (!ensureWorker()) {
return;
}
if (state.isException()) {
responsePromise.setException(std::move(state.exception()));
} else {
fb_status response;
auto exception = ThriftClient::recv_wrapped_getStatus(response, state);
if (exception) {
responsePromise.setException(std::move(exception));
} else {
responsePromise.setValue(response);
}
}
}

ThriftClientPtr client_ptr;
folly::Promise<fb_status> responsePromise;
Worker* original_worker;
};

private:
// Runs arbitrary code on client's worker thread and provides Thrift client
// created on this worker
template <typename Func>
typename std::result_of<Func(thrift::LogDeviceAPIAsyncClient&)>::type
runWithClient(Func cb) {
typename std::result_of<Func(ThriftClientPtr)>::type runWithClient(Func cb) {
ClientImpl* impl = checked_downcast<ClientImpl*>(client_.get());
return run_on_worker(&(impl->getProcessor()), 0, [&]() {
auto worker = Worker::onThisThread();
node_index_t nid{0};
auto thrift_client = worker->getThriftRouter()->getApiClient(nid);
return cb(*thrift_client);
return cb(std::move(thrift_client));
});
}

Expand Down

0 comments on commit a8fe5d2

Please sign in to comment.