Skip to content

Commit

Permalink
[core] Fix bugs in data locality (#24698)
Browse files Browse the repository at this point in the history
This fixes two bugs in data locality:

    When a dependent task is already in the CoreWorker's queue, we ran the data locality policy to choose a raylet before we added the first location for the dependency, so it would appear as if the dependency was not available anywhere.
    The locality policy did not take into account spilled locations.

Added C++ unit tests and Python tests for the above.
Related issue number

Fixes #24267.
  • Loading branch information
stephanie-wang authored May 20, 2022
1 parent 5ac29c0 commit eaec96d
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 10 deletions.
44 changes: 44 additions & 0 deletions python/ray/tests/test_scheduling_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,50 @@ def check_backlog_info():
cluster.shutdown()


def test_data_locality_spilled_objects(
ray_start_cluster_enabled, fs_only_object_spilling_config
):
cluster = ray_start_cluster_enabled
object_spilling_config, _ = fs_only_object_spilling_config
cluster.add_node(
num_cpus=1,
object_store_memory=100 * 1024 * 1024,
_system_config={
"min_spilling_size": 1,
"object_spilling_config": object_spilling_config,
},
)
ray.init(cluster.address)
cluster.add_node(
num_cpus=1, object_store_memory=100 * 1024 * 1024, resources={"remote": 1}
)

@ray.remote(resources={"remote": 1})
def f():
return (
np.zeros(50 * 1024 * 1024, dtype=np.uint8),
ray.runtime_context.get_runtime_context().node_id,
)

@ray.remote
def check_locality(x):
_, node_id = x
assert node_id == ray.runtime_context.get_runtime_context().node_id

# Check locality works when dependent task is already submitted by the time
# the upstream task finishes.
for _ in range(5):
ray.get(check_locality.remote(f.remote()))

# Check locality works when some objects were spilled.
xs = [f.remote() for _ in range(5)]
ray.wait(xs, num_returns=len(xs), fetch_local=False)
for i, x in enumerate(xs):
task = check_locality.remote(x)
print(i, x, task)
ray.get(task)


if __name__ == "__main__":
import pytest

Expand Down
2 changes: 1 addition & 1 deletion src/mock/ray/core_worker/lease_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MockLocalityDataProviderInterface : public LocalityDataProviderInterface {
MOCK_METHOD(absl::optional<LocalityData>,
GetLocalityData,
(const ObjectID &object_id),
(override));
(const override));
};

} // namespace core
Expand Down
3 changes: 2 additions & 1 deletion src/ray/core_worker/lease_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ struct LocalityData {
/// Interface for providers of locality data to the lease policy.
class LocalityDataProviderInterface {
public:
virtual absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id) = 0;
virtual absl::optional<LocalityData> GetLocalityData(
const ObjectID &object_id) const = 0;

virtual ~LocalityDataProviderInterface() {}
};
Expand Down
7 changes: 5 additions & 2 deletions src/ray/core_worker/reference_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,7 @@ bool ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id,
}

absl::optional<LocalityData> ReferenceCounter::GetLocalityData(
const ObjectID &object_id) {
const ObjectID &object_id) const {
absl::MutexLock lock(&mutex_);
// Uses the reference table to return locality data for an object.
auto it = object_id_refs_.find(object_id);
Expand All @@ -1281,7 +1281,10 @@ absl::optional<LocalityData> ReferenceCounter::GetLocalityData(
// locations.
// - If we don't own this object, this will contain a snapshot of the object locations
// at future resolution time.
const auto &node_ids = it->second.locations;
auto node_ids = it->second.locations;
if (!it->second.spilled_node_id.IsNil()) {
node_ids.emplace(it->second.spilled_node_id);
}

// We should only reach here if we have valid locality data to return.
absl::optional<LocalityData> locality_data(
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/reference_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ class ReferenceCounter : public ReferenceCounterInterface,
///
/// \param[in] object_id Object whose locality data we want.
/// \return Locality data.
absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id);
absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id) const;

/// Report locality data for object. This is used by the FutureResolver to report
/// locality data for borrowed refs.
Expand Down
7 changes: 5 additions & 2 deletions src/ray/core_worker/task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,14 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
const auto nested_refs =
VectorFromProtobuf<rpc::ObjectReference>(return_object.nested_inlined_refs());
if (return_object.in_plasma()) {
// NOTE(swang): We need to add the location of the object before marking
// it as local in the in-memory store so that the data locality policy
// will choose the right raylet for any queued dependent tasks.
const auto pinned_at_raylet_id = NodeID::FromBinary(worker_addr.raylet_id());
reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id);
// Mark it as in plasma with a dummy object.
RAY_CHECK(
in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id));
const auto pinned_at_raylet_id = NodeID::FromBinary(worker_addr.raylet_id());
reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id);
} else {
// NOTE(swang): If a direct object was promoted to plasma, then we do not
// record the node ID that it was pinned at, which means that we will not
Expand Down
6 changes: 3 additions & 3 deletions src/ray/core_worker/test/lease_policy_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ class MockLocalityDataProvider : public LocalityDataProviderInterface {
MockLocalityDataProvider(absl::flat_hash_map<ObjectID, LocalityData> locality_data)
: locality_data_(locality_data) {}

absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id) {
absl::optional<LocalityData> GetLocalityData(const ObjectID &object_id) const {
num_locality_data_fetches++;
return locality_data_[object_id];
};

~MockLocalityDataProvider() {}

int num_locality_data_fetches = 0;
absl::flat_hash_map<ObjectID, LocalityData> locality_data_;
mutable int num_locality_data_fetches = 0;
mutable absl::flat_hash_map<ObjectID, LocalityData> locality_data_;
};

absl::optional<rpc::Address> MockNodeAddrFactory(const NodeID &node_id) {
Expand Down
9 changes: 9 additions & 0 deletions src/ray/core_worker/test/reference_count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,15 @@ TEST_F(ReferenceCountTest, TestGetLocalityData) {
ASSERT_EQ(locality_data_obj1->nodes_containing_object,
absl::flat_hash_set<NodeID>({node1}));

// Include spilled locations in locality data.
rc->RemoveObjectLocation(obj1, node1);
locality_data_obj1 = rc->GetLocalityData(obj1);
ASSERT_EQ(locality_data_obj1->nodes_containing_object, absl::flat_hash_set<NodeID>({}));
rc->HandleObjectSpilled(obj1, "spill_loc", node1);
locality_data_obj1 = rc->GetLocalityData(obj1);
ASSERT_EQ(locality_data_obj1->nodes_containing_object,
absl::flat_hash_set<NodeID>({node1}));

// Borrowed object with defined object size and at least one node location should
// return valid locality data.
rc->AddLocalReference(obj2, "file.py:43");
Expand Down
25 changes: 25 additions & 0 deletions src/ray/core_worker/test/task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,31 @@ TEST_F(TaskManagerTest, TestLineageEvicted) {
ASSERT_FALSE(reference_counter_->HasReference(return_id));
}

TEST_F(TaskManagerTest, TestLocalityDataAdded) {
auto spec = CreateTaskHelper(1, {});
auto return_id = spec.ReturnId(0);
auto node_id = NodeID::FromRandom();
int object_size = 100;
store_->GetAsync(return_id, [&](std::shared_ptr<RayObject> obj) {
// By the time the return object is available to get, we should be able
// to get the locality data too.
auto locality_data = reference_counter_->GetLocalityData(return_id);
ASSERT_TRUE(locality_data.has_value());
ASSERT_EQ(locality_data->object_size, object_size);
ASSERT_TRUE(locality_data->nodes_containing_object.contains(node_id));
});

rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
return_object->set_object_id(return_id.Binary());
return_object->set_in_plasma(true);
return_object->set_size(object_size);
rpc::Address worker_addr;
worker_addr.set_raylet_id(node_id.Binary());
manager_.AddPendingTask(rpc::Address(), spec, "", 0);
manager_.CompletePendingTask(spec.TaskId(), reply, worker_addr);
}

// Test to make sure that the task spec and dependencies for an object are
// pinned when lineage pinning is enabled in the ReferenceCounter.
TEST_F(TaskManagerLineageTest, TestLineagePinned) {
Expand Down

0 comments on commit eaec96d

Please sign in to comment.