Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Destruct reply of GRPC as early as possible #14598

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions src/ray/rpc/client_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ class ClientCallImpl : public ClientCall {
/// Constructor.
///
/// \param[in] callback The callback function to handle the reply.
explicit ClientCallImpl(const ClientCallback<Reply> &callback) : callback_(callback) {}
explicit ClientCallImpl(
const std::function<void(const Status &status, std::shared_ptr<Reply>)> &callback)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not update the definition of ClientCallback?

: reply_(new Reply), callback_(callback) {}

Status GetStatus() override {
absl::MutexLock lock(&mutex_);
Expand All @@ -81,16 +83,16 @@ class ClientCallImpl : public ClientCall {
status = return_status_;
}
if (callback_ != nullptr) {
callback_(status, reply_);
callback_(status, std::move(reply_));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving a shared_ptr seems unnecessary.

}
}

private:
/// The reply message.
Reply reply_;
std::shared_ptr<Reply> reply_;

/// The callback function to handle the reply.
ClientCallback<Reply> callback_;
std::function<void(const Status &status, std::shared_ptr<Reply>)> callback_;

/// The response reader.
std::unique_ptr<grpc_impl::ClientAsyncResponseReader<Reply>> response_reader_;
Expand Down Expand Up @@ -204,7 +206,13 @@ class ClientCallManager {
typename GrpcService::Stub &stub,
const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
const Request &request, const ClientCallback<Reply> &callback) {
auto call = std::make_shared<ClientCallImpl<Reply>>(callback);
auto call = std::make_shared<ClientCallImpl<Reply>>(
[this, callback](const Status &status, std::shared_ptr<Reply> reply) {
if (callback && !main_service_.stopped() && !shutdown_) {
main_service_.post([status, reply, callback] { callback(status, *reply); });
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that you copied the std::shared_ptr<Reply> here. I don't know what's the optimization here.

}
});

// Send request.
// Find the next completion queue to wait for response.
call->response_reader_ = (stub.*prepare_async_function)(
Expand All @@ -218,7 +226,7 @@ class ClientCallManager {
// `ClientCall` is safe to use. But `response_reader_->Finish` only accepts a raw
// pointer.
auto tag = new ClientCallTag(call);
call->response_reader_->Finish(&call->reply_, &call->status_, (void *)tag);
call->response_reader_->Finish(call->reply_.get(), &call->status_, (void *)tag);
return call;
}

Expand Down Expand Up @@ -248,16 +256,10 @@ class ClientCallManager {
} else if (status != grpc::CompletionQueue::TIMEOUT) {
auto tag = reinterpret_cast<ClientCallTag *>(got_tag);
tag->GetCall()->SetReturnStatus();
if (ok && !main_service_.stopped() && !shutdown_) {
// Post the callback to the main event loop.
main_service_.post([tag]() {
tag->GetCall()->OnReplyReceived();
// The call is finished, and we can delete this tag now.
delete tag;
});
} else {
delete tag;
if (ok) {
tag->GetCall()->OnReplyReceived();
}
delete tag;
}
}
}
Expand Down
18 changes: 8 additions & 10 deletions src/ray/rpc/server_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class ServerCallImpl : public ServerCall {
// Handle service for rpc call has stopped, we must handle the call here
// to send reply and remove it from cq
RAY_LOG(DEBUG) << "Handle service has been closed.";
SendReply(Status::Invalid("HandleServiceClosed"));
SendReply(Reply(), Status::Invalid("HandleServiceClosed"));
}
}

Expand All @@ -157,10 +157,11 @@ class ServerCallImpl : public ServerCall {
// We create this before handling the request so that the it can be populated by
// the completion queue in the background if a new request comes in.
factory.CreateCall();
auto reply = std::make_shared<Reply>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you document that we want the reply to be destructed once it's sent instead of being destructed on the next request arriving?

(service_handler_.*handle_request_function_)(
request_, &reply_,
[this](Status status, std::function<void()> success,
std::function<void()> failure) {
request_, reply.get(),
[this, reply](Status status, std::function<void()> success,
std::function<void()> failure) {
// These two callbacks must be set before `SendReply`, because `SendReply`
// is async and this `ServerCall` might be deleted right after `SendReply`.
send_reply_success_callback_ = std::move(success);
Expand All @@ -169,7 +170,7 @@ class ServerCallImpl : public ServerCall {
// When the handler is done with the request, tell gRPC to finish this request.
// Must send reply at the bottom of this callback, once we invoke this funciton,
// this server call might be deleted
SendReply(status);
SendReply(*reply, status);
});
}

Expand All @@ -189,9 +190,9 @@ class ServerCallImpl : public ServerCall {

private:
/// Tell gRPC to finish this request and send reply asynchronously.
void SendReply(const Status &status) {
void SendReply(const Reply &reply, const Status &status) {
state_ = ServerCallState::SENDING_REPLY;
response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this);
response_writer_.Finish(reply, RayStatusToGrpcStatus(status), this);
}

/// State of this call.
Expand Down Expand Up @@ -219,9 +220,6 @@ class ServerCallImpl : public ServerCall {
/// The request message.
Request request_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not optimize request as well?


/// The reply message.
Reply reply_;

/// The callback when sending reply successes.
std::function<void()> send_reply_success_callback_ = nullptr;

Expand Down