Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify AddSpilledUrl into UpdateObjectLocationBatch RPC #23872

Merged
merged 13 commits into from
Apr 18, 2022
6 changes: 0 additions & 6 deletions src/mock/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,6 @@ class MockCoreWorker : public CoreWorker {
rpc::SpillObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
HandleAddSpilledUrl,
(const rpc::AddSpilledUrlRequest &request,
rpc::AddSpilledUrlReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
HandleRestoreSpilledObjects,
(const rpc::RestoreSpilledObjectsRequest &request,
Expand Down
5 changes: 0 additions & 5 deletions src/mock/ray/rpc/worker/core_worker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,6 @@ class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientIn
(const DeleteSpilledObjectsRequest &request,
const ClientCallback<DeleteSpilledObjectsReply> &callback),
(override));
MOCK_METHOD(void,
AddSpilledUrl,
(const AddSpilledUrlRequest &request,
const ClientCallback<AddSpilledUrlReply> &callback),
(override));
MOCK_METHOD(void,
PlasmaObjectReady,
(const PlasmaObjectReadyRequest &request,
Expand Down
53 changes: 26 additions & 27 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2886,19 +2886,23 @@ void CoreWorker::HandleUpdateObjectLocationBatch(
return;
}
const auto &node_id = NodeID::FromBinary(request.node_id());
const auto &object_location_states = request.object_location_states();
const auto &object_location_updates = request.object_location_updates();

for (const auto &object_location_state : object_location_states) {
const auto &object_id = ObjectID::FromBinary(object_location_state.object_id());
const auto &state = object_location_state.state();
for (const auto &object_location_update : object_location_updates) {
const auto &object_id = ObjectID::FromBinary(object_location_update.object_id());

if (state == rpc::ObjectLocationState::ADDED) {
if (object_location_update.has_spilled_url()) {
AddSpilledObjectLocationOwner(
jjyao marked this conversation as resolved.
Show resolved Hide resolved
object_id,
object_location_update.spilled_url(),
NodeID::FromBinary(object_location_update.spilled_node_id()));
}

if (object_location_update.has_in_plasma() && object_location_update.in_plasma()) {
jjyao marked this conversation as resolved.
Show resolved Hide resolved
AddObjectLocationOwner(object_id, node_id);
} else if (state == rpc::ObjectLocationState::REMOVED) {
} else if (object_location_update.has_in_plasma() &&
!object_location_update.in_plasma()) {
RemoveObjectLocationOwner(object_id, node_id);
} else {
jjyao marked this conversation as resolved.
Show resolved Hide resolved
RAY_LOG(FATAL) << "Invalid object location state " << state
<< " has been received.";
}
}

Expand All @@ -2907,6 +2911,19 @@ void CoreWorker::HandleUpdateObjectLocationBatch(
/*failure_callback_on_reply*/ nullptr);
}

void CoreWorker::AddSpilledObjectLocationOwner(const ObjectID &object_id,
const std::string &spilled_url,
const NodeID &spilled_node_id) {
RAY_LOG(DEBUG) << "Received object spilled location update for object " << object_id
<< ", which has been spilled to " << spilled_url << " on node "
<< spilled_node_id;
auto reference_exists =
reference_counter_->HandleObjectSpilled(object_id, spilled_url, spilled_node_id);
if (!reference_exists) {
RAY_LOG(DEBUG) << "Object " << object_id << " not found";
}
}

void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id,
const NodeID &node_id) {
if (gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/true) == nullptr) {
Expand Down Expand Up @@ -3158,24 +3175,6 @@ void CoreWorker::HandleSpillObjects(const rpc::SpillObjectsRequest &request,
}
}

void CoreWorker::HandleAddSpilledUrl(const rpc::AddSpilledUrlRequest &request,
rpc::AddSpilledUrlReply *reply,
rpc::SendReplyCallback send_reply_callback) {
const ObjectID object_id = ObjectID::FromBinary(request.object_id());
const std::string &spilled_url = request.spilled_url();
const NodeID node_id = NodeID::FromBinary(request.spilled_node_id());
RAY_LOG(DEBUG) << "Received AddSpilledUrl request for object " << object_id
<< ", which has been spilled to " << spilled_url << " on node "
<< node_id;
auto reference_exists = reference_counter_->HandleObjectSpilled(
object_id, spilled_url, node_id, request.size());
Status status =
reference_exists
? Status::OK()
: Status::ObjectNotFound("Object " + object_id.Hex() + " not found");
send_reply_callback(status, nullptr, nullptr);
}

void CoreWorker::HandleRestoreSpilledObjects(
const rpc::RestoreSpilledObjectsRequest &request,
rpc::RestoreSpilledObjectsReply *reply,
Expand Down
9 changes: 4 additions & 5 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -753,11 +753,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
rpc::SpillObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

// Add spilled URL to owned reference.
void HandleAddSpilledUrl(const rpc::AddSpilledUrlRequest &request,
rpc::AddSpilledUrlReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

// Restore objects from external storage.
void HandleRestoreSpilledObjects(const rpc::RestoreSpilledObjectsRequest &request,
rpc::RestoreSpilledObjectsReply *reply,
Expand Down Expand Up @@ -1002,6 +997,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// messages.
void ProcessPubsubCommands(const Commands &commands, const NodeID &subscriber_id);

void AddSpilledObjectLocationOwner(const ObjectID &object_id,
const std::string &spilled_url,
const NodeID &spilled_node_id);

void AddObjectLocationOwner(const ObjectID &object_id, const NodeID &node_id);

void RemoveObjectLocationOwner(const ObjectID &object_id, const NodeID &node_id);
Expand Down
15 changes: 1 addition & 14 deletions src/ray/core_worker/reference_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1202,19 +1202,9 @@ absl::optional<absl::flat_hash_set<NodeID>> ReferenceCounter::GetObjectLocations
return it->second.locations;
}

size_t ReferenceCounter::GetObjectSize(const ObjectID &object_id) const {
absl::MutexLock lock(&mutex_);
auto it = object_id_refs_.find(object_id);
if (it == object_id_refs_.end()) {
return 0;
}
return it->second.object_size;
}

bool ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id,
const std::string spilled_url,
const NodeID &spilled_node_id,
int64_t size) {
const NodeID &spilled_node_id) {
absl::MutexLock lock(&mutex_);
auto it = object_id_refs_.find(object_id);
if (it == object_id_refs_.end()) {
Expand All @@ -1239,9 +1229,6 @@ bool ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id,
if (!spilled_node_id.IsNil()) {
it->second.spilled_node_id = spilled_node_id;
}
if (size > 0) {
it->second.object_size = size;
jjyao marked this conversation as resolved.
Show resolved Hide resolved
}
PushToLocationSubscribers(it);
} else {
RAY_LOG(DEBUG) << "Object " << object_id << " spilled to dead node "
Expand Down
10 changes: 1 addition & 9 deletions src/ray/core_worker/reference_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,25 +431,17 @@ class ReferenceCounter : public ReferenceCounterInterface,
rpc::WorkerObjectLocationsPubMessage *object_info)
LOCKS_EXCLUDED(mutex_);

/// Get an object's size. This will return 0 if the object is out of scope.
///
/// \param[in] object_id The object whose size to get.
/// \return Object size, or 0 if the object is out of scope.
size_t GetObjectSize(const ObjectID &object_id) const;

/// Handle an object has been spilled to external storage.
///
/// This notifies the primary raylet that the object is safe to release and
/// records the spill URL, spill node ID, and updated object size.
/// \param[in] object_id The object that has been spilled.
/// \param[in] spilled_url The URL to which the object has been spilled.
/// \param[in] spilled_node_id The ID of the node on which the object was spilled.
/// \param[in] size The size of the object.
/// \return True if the reference exists and is in scope, false otherwise.
bool HandleObjectSpilled(const ObjectID &object_id,
const std::string spilled_url,
const NodeID &spilled_node_id,
int64_t size);
const NodeID &spilled_node_id);

/// Get locality data for object. This is used by the leasing policy to implement
/// locality-aware leasing.
Expand Down
64 changes: 58 additions & 6 deletions src/ray/core_worker/test/core_worker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ class CoreWorkerTest : public ::testing::Test {
// start raylet on each node. Assign each node with different resources so that
// a task can be scheduled to the desired node.
for (int i = 0; i < num_nodes; i++) {
raylet_socket_names_[i] =
TestSetupUtil::StartRaylet("127.0.0.1",
node_manager_port + i,
"127.0.0.1:6379",
"\"CPU,4.0,resource" + std::to_string(i) + ",10\"",
&raylet_store_socket_names_[i]);
raylet_socket_names_[i] = TestSetupUtil::StartRaylet(
"127.0.0.1",
node_manager_port + i,
"127.0.0.1:6379",
"\"CPU,4.0,object_store_memory,100,resource" + std::to_string(i) + ",10\"",
&raylet_store_socket_names_[i]);
}
}

Expand Down Expand Up @@ -877,6 +877,58 @@ TEST_F(SingleNodeTest, TestObjectInterface) {
ASSERT_TRUE(results[1]->IsException());
}

TEST_F(SingleNodeTest, TestHandleUpdateObjectLocationBatch) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah thanks for adding this test, but I think the existing Python tests are probably enough for this PR... I've actually been meaning to deprecate this test suite since the maintainability is not really worth the test coverage compared to e2e integration tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also iirc, this test doesn't even run in the master CI!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I also realized that. Also the current core_worker_test is like semi integration test, it needs to launch gcs, raylet, etc. We should try to make it more unit-testable via techniques like dependency ingestion.

auto &driver = CoreWorkerProcess::GetCoreWorker();
auto buffer = GenerateRandomBuffer();

ObjectID object_id;
RAY_CHECK_OK(driver.Put(
RayObject(buffer, nullptr, std::vector<rpc::ObjectReference>()), {}, &object_id));
rpc::UpdateObjectLocationBatchRequest request;
rpc::UpdateObjectLocationBatchReply reply;
request.set_intended_worker_id(driver.GetWorkerID().Binary());
request.set_node_id(driver.GetCurrentNodeId().Binary());
auto update = request.add_object_location_updates();
update->set_object_id(object_id.Binary());
update->set_spilled_url("url1");
update->set_spilled_node_id(driver.GetCurrentNodeId().Binary());
update->set_in_plasma(false);
driver.HandleUpdateObjectLocationBatch(
request,
&reply,
[](Status status, std::function<void()> success, std::function<void()> failure) {});

rpc::GetObjectLocationsOwnerRequest get_request;
get_request.mutable_object_location_request()->set_intended_worker_id(
driver.GetWorkerID().Binary());
get_request.mutable_object_location_request()->set_object_id(object_id.Binary());
rpc::GetObjectLocationsOwnerReply get_reply;
driver.HandleGetObjectLocationsOwner(
get_request,
&get_reply,
[](Status status, std::function<void()> success, std::function<void()> failure) {});
ASSERT_EQ(get_reply.object_location_info().node_ids().size(), 0);
ASSERT_EQ(get_reply.object_location_info().spilled_url(), "url1");
ASSERT_EQ(get_reply.object_location_info().spilled_node_id(),
driver.GetCurrentNodeId().Binary());

request.clear_object_location_updates();
update = request.add_object_location_updates();
update->set_object_id(object_id.Binary());
update->set_in_plasma(true);
driver.HandleUpdateObjectLocationBatch(
request,
&reply,
[](Status status, std::function<void()> success, std::function<void()> failure) {});
driver.HandleGetObjectLocationsOwner(
get_request,
&get_reply,
[](Status status, std::function<void()> success, std::function<void()> failure) {});
ASSERT_EQ(get_reply.object_location_info().node_ids().size(), 1);
ASSERT_EQ(get_reply.object_location_info().node_ids(0),
driver.GetCurrentNodeId().Binary());
}

TEST_F(SingleNodeTest, TestNormalTaskLocal) { TestNormalTask(); }

TEST_F(SingleNodeTest, TestCancelTasks) {
Expand Down
25 changes: 25 additions & 0 deletions src/ray/core_worker/test/reference_count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,31 @@ TEST_F(ReferenceCountTest, TestReferenceStats) {
rc->RemoveLocalReference(id2, nullptr);
}

TEST_F(ReferenceCountTest, TestHandleObjectSpilled) {
ObjectID obj1 = ObjectID::FromRandom();
NodeID node1 = NodeID::FromRandom();
rpc::Address address;
address.set_ip_address("1234");

int64_t object_size = 100;
rc->AddOwnedObject(obj1,
{},
address,
"file1.py:42",
object_size,
false,
/*add_local_ref=*/true,
absl::optional<NodeID>(node1));
rc->HandleObjectSpilled(obj1, "url1", node1);
rpc::WorkerObjectLocationsPubMessage object_info;
Status status = rc->FillObjectInformation(obj1, &object_info);
ASSERT_TRUE(status.ok());
ASSERT_EQ(object_info.object_size(), object_size);
ASSERT_EQ(object_info.spilled_url(), "url1");
ASSERT_EQ(object_info.spilled_node_id(), node1.Binary());
rc->RemoveLocalReference(obj1, nullptr);
}

// Tests fetching of locality data from reference table.
TEST_F(ReferenceCountTest, TestGetLocalityData) {
ObjectID obj1 = ObjectID::FromRandom();
Expand Down
6 changes: 6 additions & 0 deletions src/ray/object_manager/object_directory.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ class IObjectDirectory {
const NodeID &node_id,
const ObjectInfo &object_info) = 0;

virtual void ReportObjectSpilled(const ObjectID &object_id,
const NodeID &node_id,
const rpc::Address &owner_address,
const std::string &spilled_url,
const NodeID &spilled_node_id) = 0;

/// Record metrics.
virtual void RecordMetrics(uint64_t duration_ms) = 0;

Expand Down
Loading