Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[direct call] changes raylet to push tasks to worker #5140

Merged
merged 32 commits into from
Jul 11, 2019
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7f635b4
refactor grpc server
zhijunfu Jun 24, 2019
2205c87
format
zhijunfu Jun 24, 2019
0703b75
change GetTask() to PushTask()
zhijunfu Jun 24, 2019
6d5228a
merge code, and fix test
zhijunfu Jun 25, 2019
4c3eb6e
merge
zhijunfu Jun 26, 2019
d1b20e5
change PushTask to AssignTask
zhijunfu Jun 26, 2019
1a663bc
format
zhijunfu Jun 26, 2019
4ef00b2
merge
zhijunfu Jul 8, 2019
6caa126
add resource_ids
zhijunfu Jul 8, 2019
302e79f
move done_callback to server call
zhijunfu Jul 8, 2019
70d6e26
remove SetTaskHandler and initialize it in task receiver's constructor
zhijunfu Jul 8, 2019
17b22e4
format
zhijunfu Jul 8, 2019
dbf666c
resolve comments
zhijunfu Jul 8, 2019
a5f301d
update
zhijunfu Jul 8, 2019
0c762f1
update
zhijunfu Jul 8, 2019
819614f
Update src/ray/core_worker/core_worker.cc
zhijunfu Jul 9, 2019
beb01be
resolve comments
zhijunfu Jul 9, 2019
be05ed6
merge
zhijunfu Jul 9, 2019
878cae4
format
zhijunfu Jul 9, 2019
722511b
Merge branch 'worker_grpc' of https://github.com/ant-tech-alliance/ra…
zhijunfu Jul 9, 2019
3514156
merge
zhijunfu Jul 9, 2019
bddc209
Update src/ray/core_worker/transport/raylet_transport.cc
zhijunfu Jul 9, 2019
78bfc24
resolve comments
zhijunfu Jul 9, 2019
df2a6ff
Merge branch 'worker_grpc' of https://github.com/ant-tech-alliance/ra…
zhijunfu Jul 9, 2019
32b3d3e
resolve comments
zhijunfu Jul 10, 2019
d157b84
fix build
zhijunfu Jul 10, 2019
42179ea
format
zhijunfu Jul 10, 2019
20cbf1f
Merge branch 'master' of https://github.com/ray-project/ray into work…
zhijunfu Jul 10, 2019
85f2d73
merge
zhijunfu Jul 11, 2019
6600779
fix
raulchen Jul 11, 2019
5e66995
format
raulchen Jul 11, 2019
b60574d
noop
stephanie-wang Jul 11, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace ray {
CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id,
CoreWorkerTaskExecutionInterface::TaskExecutor execution_callback)
const CoreWorkerTaskExecutionInterface::TaskExecutor &execution_callback)
: worker_type_(worker_type),
language_(language),
raylet_socket_(raylet_socket),
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class CoreWorker {
CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id,
CoreWorkerTaskExecutionInterface::TaskExecutor execution_callback = nullptr);
const CoreWorkerTaskExecutionInterface::TaskExecutor &execution_callback);

/// Type of this worker.
enum WorkerType WorkerType() const { return worker_type_; }
Expand Down
12 changes: 6 additions & 6 deletions src/ray/core_worker/core_worker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class CoreWorkerTest : public ::testing::Test {

void TestNormalTask(const std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], JobID::FromRandom());
raylet_socket_names_[0], JobID::FromRandom(), nullptr);

// Test pass by value.
{
Expand Down Expand Up @@ -184,7 +184,7 @@ class CoreWorkerTest : public ::testing::Test {

void TestActorTask(const std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], JobID::FromRandom());
raylet_socket_names_[0], JobID::FromRandom(), nullptr);

std::unique_ptr<ActorHandle> actor_handle;

Expand Down Expand Up @@ -335,7 +335,7 @@ TEST_F(ZeroNodeTest, TestActorHandle) {
TEST_F(SingleNodeTest, TestObjectInterface) {
CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
JobID::FromRandom());
JobID::FromRandom(), nullptr);

uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
Expand Down Expand Up @@ -398,10 +398,10 @@ TEST_F(SingleNodeTest, TestObjectInterface) {

TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) {
CoreWorker worker1(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], JobID::FromRandom());
raylet_socket_names_[0], JobID::FromRandom(), nullptr);

CoreWorker worker2(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[1],
raylet_socket_names_[1], JobID::FromRandom());
raylet_socket_names_[1], JobID::FromRandom(), nullptr);

uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
Expand Down Expand Up @@ -487,7 +487,7 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodes) {
TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) {
try {
CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, "",
raylet_socket_names_[0], JobID::FromRandom());
raylet_socket_names_[0], JobID::FromRandom(), nullptr);
} catch (const std::exception &e) {
std::cout << "Caught exception when constructing core worker: " << e.what();
}
Expand Down
61 changes: 31 additions & 30 deletions src/ray/core_worker/mock_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "ray/core_worker/store_provider/store_provider.h"
#include "ray/core_worker/task_execution.h"

using namespace std::placeholders;

namespace ray {

/// A mock C++ worker used by core_worker_test.cc to verify the task submission/execution
Expand All @@ -16,43 +18,42 @@ namespace ray {
/// for more details on how this class is used.
class MockWorker {
public:
MockWorker(const std::string &store_socket, const std::string &raylet_socket) {
auto executor_func = [this](const RayFunction &ray_function,
const std::vector<std::shared_ptr<RayObject>> &args,
const TaskInfo &task_info, int num_returns) {
// Note that this doesn't include dummy object id.
RAY_CHECK(num_returns >= 0);
MockWorker(const std::string &store_socket, const std::string &raylet_socket)
: worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
JobID::FromRandom(),
std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4)) {}

// Merge all the content from input args.
std::vector<uint8_t> buffer;
for (const auto &arg : args) {
auto &data = arg->GetData();
buffer.insert(buffer.end(), data->Data(), data->Data() + data->Size());
}
void Run() {
// Start executing tasks.
worker_.Execution().Run();
}

auto return_value = RayObject(
std::make_shared<LocalMemoryBuffer>(buffer.data(), buffer.size()), nullptr);
private:
Status ExecuteTask(const RayFunction &ray_function,
const std::vector<std::shared_ptr<RayObject>> &args,
const TaskInfo &task_info, int num_returns) {
// Note that this doesn't include dummy object id.
RAY_CHECK(num_returns >= 0);

// Write the merged content to each of return ids.
for (int i = 0; i < num_returns; i++) {
ObjectID id = ObjectID::ForTaskReturn(task_info.task_id, i + 1);
RAY_CHECK_OK(worker_->Objects().Put(return_value, id));
}
return Status::OK();
};
// Merge all the content from input args.
std::vector<uint8_t> buffer;
for (const auto &arg : args) {
auto &data = arg->GetData();
buffer.insert(buffer.end(), data->Data(), data->Data() + data->Size());
}

worker_ = std::unique_ptr<CoreWorker>(
new CoreWorker(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
JobID::FromRandom(), executor_func));
}
auto return_value = RayObject(
std::make_shared<LocalMemoryBuffer>(buffer.data(), buffer.size()), nullptr);

void Run() {
// Start executing tasks.
worker_->Execution().Run();
// Write the merged content to each of return ids.
for (int i = 0; i < num_returns; i++) {
ObjectID id = ObjectID::ForTaskReturn(task_info.task_id, i + 1);
RAY_CHECK_OK(worker_.Objects().Put(return_value, id));
}
return Status::OK();
}

private:
std::unique_ptr<CoreWorker> worker_;
CoreWorker worker_;
};

} // namespace ray
Expand Down
8 changes: 4 additions & 4 deletions src/ray/core_worker/task_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface(
std::placeholders::_1);
task_receivers_.emplace(
static_cast<int>(TaskTransportType::RAYLET),
std::unique_ptr<CoreWorkerRayletTaskReceiver>(
new CoreWorkerRayletTaskReceiver(main_service_, worker_server_, func)));
std::unique_ptr<CoreWorkerRayletTaskReceiver>(new CoreWorkerRayletTaskReceiver(
raylet_client, main_service_, worker_server_, func)));

// Start RPC server after all the task receivers are properly initialized.
worker_server_.Run();
Expand Down Expand Up @@ -60,12 +60,12 @@ Status CoreWorkerTaskExecutionInterface::ExecuteTask(
return status;
}

Status CoreWorkerTaskExecutionInterface::Run() {
void CoreWorkerTaskExecutionInterface::Run() {
// Run main IO service.
main_service_.run();

// should never reach here.
return Status::OK();
RAY_LOG(FATAL) << "should never reach here after running main io service";
}

Status CoreWorkerTaskExecutionInterface::BuildArgsForExecutor(
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker/task_execution.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class CoreWorkerTaskExecutionInterface {
const TaskExecutor &executor);

/// Start receving and executes tasks in a infinite loop.
/// \return Status.
Status Run();
/// \return void.
void Run();

private:
/// Build arguments for task executor. This would loop through all the arguments
Expand Down
14 changes: 11 additions & 3 deletions src/ray/core_worker/transport/raylet_transport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ Status CoreWorkerRayletTaskSubmitter::SubmitTask(const TaskSpec &task) {
}

CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
boost::asio::io_service &io_service, rpc::GrpcServer &server,
const TaskHandler &task_handler)
: task_service_(io_service, *this), task_handler_(task_handler) {
std::unique_ptr<RayletClient> &raylet_client, boost::asio::io_service &io_service,
rpc::GrpcServer &server, const TaskHandler &task_handler)
: raylet_client_(raylet_client),
task_service_(io_service, *this),
task_handler_(task_handler) {
server.RegisterService(task_service_);
}

Expand All @@ -26,6 +28,12 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask(
const raylet::Task task(request.task());
const auto &spec = task.GetTaskSpecification();
auto status = task_handler_(spec);
// Notify raylet the current task is done. This is to ensure that the task
// is marked as finished by raylet only after previous raylet client calls are
// completed. The rpc `done_callback` is sent via a different connection
// from raylet client connection, so it cannot guarantee the rpc reply arrives
// at raylet after a previous `NotifyUnblocked` message.
zhijunfu marked this conversation as resolved.
Show resolved Hide resolved
raylet_client_->TaskDone();
send_reply_callback(status, nullptr, nullptr);
}

Expand Down
5 changes: 4 additions & 1 deletion src/ray/core_worker/transport/raylet_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class CoreWorkerRayletTaskSubmitter : public CoreWorkerTaskSubmitter {
class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver,
public rpc::WorkerTaskHandler {
public:
CoreWorkerRayletTaskReceiver(boost::asio::io_service &io_service,
CoreWorkerRayletTaskReceiver(std::unique_ptr<RayletClient> &raylet_client,
boost::asio::io_service &io_service,
rpc::GrpcServer &server, const TaskHandler &task_handler);

/// Handle a `AssignTask` request.
Expand All @@ -46,6 +47,8 @@ class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver,
rpc::SendReplyCallback send_reply_callback) override;

private:
/// Raylet client.
std::unique_ptr<RayletClient> &raylet_client_;
/// The rpc service for `WorkerTaskService`.
rpc::WorkerTaskGrpcService task_service_;
/// The callback function to process a task.
Expand Down
12 changes: 3 additions & 9 deletions src/ray/protobuf/worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,12 @@ package ray.rpc;
import "src/ray/protobuf/common.proto";

message AssignTaskRequest {
// The ID of the task to be pushed.
bytes task_id = 1;
// The task to be assigned. This should include task_id.
// TODO(hchen): Currently, `task_spec` are represented as
// flatbutters-serialized bytes. This is because the flatbuffers-defined Task data
// structure is being used in many places. We should move Task and all related data
// structures to protobuf.
bytes task_spec = 2;
// The task to be pushed.
Task task = 1;
// A list of the resources reserved for this worker.
// TODO(zhijunfu): `resource_ids` is represented as
// flatbutters-serialized bytes, will be moved to protobuf later.
bytes resource_ids = 3;
bytes resource_ids = 2;
}

message AssignTaskReply {
Expand Down
Loading