diff --git a/python/ray/tests/test_generators.py b/python/ray/tests/test_generators.py index 7466ec6bcbae..acb05d271e30 100644 --- a/python/ray/tests/test_generators.py +++ b/python/ray/tests/test_generators.py @@ -1,6 +1,7 @@ import pytest import numpy as np import sys +import time import ray @@ -179,6 +180,26 @@ def static(num_returns): ray.get(static.remote(3)) +def test_dynamic_generator_distributed(ray_start_cluster): + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node(num_cpus=0) + ray.init(address=cluster.address) + cluster.add_node(num_cpus=1) + cluster.wait_for_nodes() + + @ray.remote(num_returns="dynamic") + def dynamic_generator(num_returns): + for i in range(num_returns): + yield np.ones(1_000_000, dtype=np.int8) * i + time.sleep(0.1) + + gen = ray.get(dynamic_generator.remote(3)) + for i, ref in enumerate(gen): + # Check that we can fetch the values from a different node. + assert ray.get(ref)[0] == i + + def test_dynamic_generator_reconstruction(ray_start_cluster): config = { "num_heartbeats_timeout": 10, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 8081f8dd4eb6..37b6accbdebc 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2976,12 +2976,6 @@ void CoreWorker::AddSpilledObjectLocationOwner( // object is spilled before the reply from the task that created the // object. Add the dynamically created object to our ref counter so that we // know that it exists. - // NOTE(swang): We don't need to do this for in-plasma object locations because: - // 1) We will add the primary copy as a location when processing the task - // reply. - // 2) It is not possible to copy the object to a second location until - // after the owner has added the object to the ref count table (since no - // raylet can get the current location of the object until this happens). RAY_CHECK(!generator_id->IsNil()); reference_counter_->AddDynamicReturn(object_id, *generator_id); } @@ -3005,6 +2999,17 @@ void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id, if (!reference_exists) { RAY_LOG(DEBUG) << "Object " + object_id.Hex() + " not found"; } + + // For generator tasks where we haven't yet received the task reply, the + // internal ObjectRefs may not be added yet, so we don't find out about these + // until the task finishes. + const auto &maybe_generator_id = task_manager_->TaskGeneratorId(object_id.TaskId()); + if (!maybe_generator_id.IsNil()) { + // The task is a generator and may not have finished yet. Add the internal + // ObjectID so that we can update its location. + reference_counter_->AddDynamicReturn(object_id, maybe_generator_id); + RAY_UNUSED(reference_counter_->AddObjectLocation(object_id, node_id)); + } } void CoreWorker::RemoveObjectLocationOwner(const ObjectID &object_id, diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 3f529d2bcbef..ebb499b251c5 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -809,5 +809,17 @@ void TaskManager::FillTaskInfo(rpc::GetCoreWorkerStatsReply *reply, reply->set_tasks_total(total); } +ObjectID TaskManager::TaskGeneratorId(const TaskID &task_id) const { + absl::MutexLock lock(&mu_); + auto it = submissible_tasks_.find(task_id); + if (it == submissible_tasks_.end()) { + return ObjectID::Nil(); + } + if (!it->second.spec.ReturnsDynamic()) { + return ObjectID::Nil(); + } + return it->second.spec.ReturnId(0); +} + } // namespace core } // namespace ray diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 26554be61b51..af48f815ddb0 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -284,13 +284,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// Fill every task information of the current worker to GetCoreWorkerStatsReply. void FillTaskInfo(rpc::GetCoreWorkerStatsReply *reply, const int64_t limit) const; - /// Update nested ref count info and store the in-memory value for a task's - /// return object. Returns true if the task's return object was returned - /// directly by value. - bool HandleTaskReturn(const ObjectID &object_id, - const rpc::ReturnObject &return_object, - const NodeID &worker_raylet_id, - bool store_in_plasma); + /// Returns the generator ID that contains the dynamically allocated + /// ObjectRefs, if the task is dynamic. Else, returns Nil. + ObjectID TaskGeneratorId(const TaskID &task_id) const; private: struct TaskEntry { @@ -367,6 +363,14 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa rpc::TaskStatus status = rpc::TaskStatus::PENDING_ARGS_AVAIL; }; + /// Update nested ref count info and store the in-memory value for a task's + /// return object. Returns true if the task's return object was returned + /// directly by value. + bool HandleTaskReturn(const ObjectID &object_id, + const rpc::ReturnObject &return_object, + const NodeID &worker_raylet_id, + bool store_in_plasma) LOCKS_EXCLUDED(mu_); + /// Remove a lineage reference to this object ID. This should be called /// whenever a task that depended on this object ID can no longer be retried. ///