Skip to content

Commit

Permalink
Revert "Revert "[core] Fix bugs in data locality (ray-project#24698)" (
Browse files Browse the repository at this point in the history
…ray-project#25035)"

This reverts commit 916c679.
  • Loading branch information
stephanie-wang committed May 23, 2022
1 parent 50d49a2 commit 4942d7f
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 10 deletions.
44 changes: 44 additions & 0 deletions python/ray/tests/test_scheduling_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,50 @@ def check_backlog_info():
cluster.shutdown()


def test_data_locality_spilled_objects(
ray_start_cluster_enabled, fs_only_object_spilling_config
):
cluster = ray_start_cluster_enabled
object_spilling_config, _ = fs_only_object_spilling_config
cluster.add_node(
num_cpus=1,
object_store_memory=100 * 1024 * 1024,
_system_config={
"min_spilling_size": 1,
"object_spilling_config": object_spilling_config,
},
)
ray.init(cluster.address)
cluster.add_node(
num_cpus=1, object_store_memory=100 * 1024 * 1024, resources={"remote": 1}
)

@ray.remote(resources={"remote": 1})
def f():
return (
np.zeros(50 * 1024 * 1024, dtype=np.uint8),
ray.runtime_context.get_runtime_context().node_id,
)

@ray.remote
def check_locality(x):
_, node_id = x
assert node_id == ray.runtime_context.get_runtime_context().node_id

# Check locality works when dependent task is already submitted by the time
# the upstream task finishes.
for _ in range(5):
ray.get(check_locality.remote(f.remote()))

# Check locality works when some objects were spilled.
xs = [f.remote() for _ in range(5)]
ray.wait(xs, num_returns=len(xs), fetch_local=False)
for i, x in enumerate(xs):
task = check_locality.remote(x)
print(i, x, task)
ray.get(task)


if __name__ == "__main__":
import pytest

Expand Down
2 changes: 1 addition & 1 deletion src/mock/ray/core_worker/lease_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MockLocalityDataProviderInterface : public LocalityDataProviderInterface {
MOCK_METHOD(absl::optional<LocalityData>,
GetLocalityData,
(const ObjectID &object_id),
(override));
(const override));
};

} // namespace core
Expand Down
3 changes: 2 additions & 1 deletion src/ray/core_worker/lease_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ struct LocalityData {
/// Interface for providers of locality data to the lease policy.
class LocalityDataProviderInterface {
public:
virtual absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id) = 0;
virtual absl::optional<LocalityData> GetLocalityData(
const ObjectID &object_id) const = 0;

virtual ~LocalityDataProviderInterface() {}
};
Expand Down
7 changes: 5 additions & 2 deletions src/ray/core_worker/reference_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,7 @@ bool ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id,
}

absl::optional<LocalityData> ReferenceCounter::GetLocalityData(
const ObjectID &object_id) {
const ObjectID &object_id) const {
absl::MutexLock lock(&mutex_);
// Uses the reference table to return locality data for an object.
auto it = object_id_refs_.find(object_id);
Expand All @@ -1281,7 +1281,10 @@ absl::optional<LocalityData> ReferenceCounter::GetLocalityData(
// locations.
// - If we don't own this object, this will contain a snapshot of the object locations
// at future resolution time.
const auto &node_ids = it->second.locations;
auto node_ids = it->second.locations;
if (!it->second.spilled_node_id.IsNil()) {
node_ids.emplace(it->second.spilled_node_id);
}

// We should only reach here if we have valid locality data to return.
absl::optional<LocalityData> locality_data(
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/reference_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ class ReferenceCounter : public ReferenceCounterInterface,
///
/// \param[in] object_id Object whose locality data we want.
/// \return Locality data.
absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id);
absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id) const;

/// Report locality data for object. This is used by the FutureResolver to report
/// locality data for borrowed refs.
Expand Down
7 changes: 5 additions & 2 deletions src/ray/core_worker/task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,14 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
const auto nested_refs =
VectorFromProtobuf<rpc::ObjectReference>(return_object.nested_inlined_refs());
if (return_object.in_plasma()) {
// NOTE(swang): We need to add the location of the object before marking
// it as local in the in-memory store so that the data locality policy
// will choose the right raylet for any queued dependent tasks.
const auto pinned_at_raylet_id = NodeID::FromBinary(worker_addr.raylet_id());
reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id);
// Mark it as in plasma with a dummy object.
RAY_CHECK(
in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id));
const auto pinned_at_raylet_id = NodeID::FromBinary(worker_addr.raylet_id());
reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id);
} else {
// NOTE(swang): If a direct object was promoted to plasma, then we do not
// record the node ID that it was pinned at, which means that we will not
Expand Down
6 changes: 3 additions & 3 deletions src/ray/core_worker/test/lease_policy_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ class MockLocalityDataProvider : public LocalityDataProviderInterface {
MockLocalityDataProvider(absl::flat_hash_map<ObjectID, LocalityData> locality_data)
: locality_data_(locality_data) {}

absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id) {
absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id) const {
num_locality_data_fetches++;
return locality_data_[object_id];
};

~MockLocalityDataProvider() {}

int num_locality_data_fetches = 0;
absl::flat_hash_map<ObjectID, LocalityData> locality_data_;
mutable int num_locality_data_fetches = 0;
mutable absl::flat_hash_map<ObjectID, LocalityData> locality_data_;
};

absl::optional<rpc::Address> MockNodeAddrFactory(const NodeID &node_id) {
Expand Down
9 changes: 9 additions & 0 deletions src/ray/core_worker/test/reference_count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,15 @@ TEST_F(ReferenceCountTest, TestGetLocalityData) {
ASSERT_EQ(locality_data_obj1->nodes_containing_object,
absl::flat_hash_set<NodeID>({node1}));

// Include spilled locations in locality data.
rc->RemoveObjectLocation(obj1, node1);
locality_data_obj1 = rc->GetLocalityData(obj1);
ASSERT_EQ(locality_data_obj1->nodes_containing_object, absl::flat_hash_set<NodeID>({}));
rc->HandleObjectSpilled(obj1, "spill_loc", node1);
locality_data_obj1 = rc->GetLocalityData(obj1);
ASSERT_EQ(locality_data_obj1->nodes_containing_object,
absl::flat_hash_set<NodeID>({node1}));

// Borrowed object with defined object size and at least one node location should
// return valid locality data.
rc->AddLocalReference(obj2, "file.py:43");
Expand Down
25 changes: 25 additions & 0 deletions src/ray/core_worker/test/task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,31 @@ TEST_F(TaskManagerTest, TestLineageEvicted) {
ASSERT_FALSE(reference_counter_->HasReference(return_id));
}

TEST_F(TaskManagerTest, TestLocalityDataAdded) {
auto spec = CreateTaskHelper(1, {});
auto return_id = spec.ReturnId(0);
auto node_id = NodeID::FromRandom();
int object_size = 100;
store_->GetAsync(return_id, [&](std::shared_ptr<RayObject> obj) {
// By the time the return object is available to get, we should be able
// to get the locality data too.
auto locality_data = reference_counter_->GetLocalityData(return_id);
ASSERT_TRUE(locality_data.has_value());
ASSERT_EQ(locality_data->object_size, object_size);
ASSERT_TRUE(locality_data->nodes_containing_object.contains(node_id));
});

rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
return_object->set_object_id(return_id.Binary());
return_object->set_in_plasma(true);
return_object->set_size(object_size);
rpc::Address worker_addr;
worker_addr.set_raylet_id(node_id.Binary());
manager_.AddPendingTask(rpc::Address(), spec, "", 0);
manager_.CompletePendingTask(spec.TaskId(), reply, worker_addr);
}

// Test to make sure that the task spec and dependencies for an object are
// pinned when lineage pinning is enabled in the ReferenceCounter.
TEST_F(TaskManagerLineageTest, TestLineagePinned) {
Expand Down

0 comments on commit 4942d7f

Please sign in to comment.