
Commit

[Core][Worker_Pool] Wait for prestarted-workers for the first job and disable run_on_all_workers flaky tests (ray-project#31836)

Why are these changes needed?

In ray-project#30883 we changed the behavior of the first driver connection to a raylet: previously, the driver connection waited for the prestarted workers to connect before returning; after ray-project#30883 it no longer does. This caused flakiness in
test_failure_4.py -k test_task_crash_after_raylet_dead_throws_node_died_error and test_ray_shutdown.py -k test_driver_dead

We restore the previous behavior, which fixes the test flakiness.

ray-project#30883 also made run_function_on_all_workers flaky; we disable those flaky tests.
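The restored behavior amounts to a small synchronization pattern: the first driver's registration does not complete until the expected number of prestarted Python workers have registered with the raylet. A minimal, hypothetical Python sketch of that pattern (names like `WorkerPool`, `expected_workers`, and `on_worker_started` are illustrative only, not Ray's actual API):

```python
import threading

class WorkerPool:
    """Toy model of the restored first-driver handshake (not Ray's real code)."""

    def __init__(self, expected_workers):
        self.expected_workers = expected_workers
        self.registered = 0
        self._all_registered = threading.Event()
        if expected_workers == 0:
            # Nothing to wait for: the driver may connect immediately.
            self._all_registered.set()

    def on_worker_started(self):
        # Each prestarted worker calls this when it registers with the raylet.
        self.registered += 1
        if self.registered >= self.expected_workers:
            self._all_registered.set()

    def register_driver(self, timeout=None):
        # The first driver's connection blocks until all prestarted workers
        # have registered (the pre-#30883 behavior this commit restores).
        # Returns True if the wait succeeded, False on timeout.
        return self._all_registered.wait(timeout)
```

In this toy model, a `register_driver` call made before all workers register times out, while one made after they have all checked in returns immediately; the real fix delays the driver's registration reply inside the raylet rather than blocking a Python thread.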

Signed-off-by: Andrea Pisoni <[email protected]>
scv119 authored and andreapiso committed Jan 22, 2023
1 parent cb2a30a commit e16eaf1
Showing 3 changed files with 5 additions and 8 deletions.
6 changes: 2 additions & 4 deletions python/ray/tests/test_advanced.py
@@ -147,7 +147,7 @@ def h(input_list):
     ray.get([h.remote([x]), h.remote([x])])


-@pytest.mark.skipif(client_test_enabled(), reason="internal api")
+@pytest.mark.skip(reason="Flaky tests")
 def test_caching_functions_to_run(shutdown_only):
     # Test that we export functions to run on all workers before the driver
     # is connected.
@@ -193,9 +193,7 @@ def f(worker_info):
     ray._private.worker.global_worker.run_function_on_all_workers(f)


-@pytest.mark.skipif(
-    client_test_enabled() or sys.platform == "win32", reason="internal api"
-)
+@pytest.mark.skip(reason="Flaky tests")
 def test_running_function_on_all_workers(ray_start_regular):
     def f(worker_info):
         sys.path.append("fake_directory")
4 changes: 1 addition & 3 deletions python/ray/tests/test_failure_4.py
@@ -537,7 +537,6 @@ def task():
     ) not in e.value.output.decode()


-@pytest.mark.skipif(sys.platform == "win32", reason="Flaky tests on Windows")
 def test_task_failure_when_driver_local_raylet_dies(ray_start_cluster):
     cluster = ray_start_cluster
     head = cluster.add_node(num_cpus=4, resources={"foo": 1})
@@ -667,13 +666,12 @@ def test_task_crash_after_raylet_dead_throws_node_died_error():
     def sleeper():
         import os

-        time.sleep(5)
+        time.sleep(3)
         os.kill(os.getpid(), 9)

     with ray.init():
         ref = sleeper.remote()

-        time.sleep(2)
         raylet = ray.nodes()[0]
         kill_raylet(raylet)
3 changes: 2 additions & 1 deletion src/ray/raylet/worker_pool.cc
@@ -765,7 +765,7 @@ void WorkerPool::OnWorkerStarted(const std::shared_ptr<WorkerInterface> &worker)
   // This is a workaround to finish driver registration after all initial workers are
   // registered to Raylet if and only if Raylet is started by a Python driver and the
   // job config is not set in `ray.init(...)`.
-  if (first_job_ == worker->GetAssignedJobId() &&
+  if (worker_type == rpc::WorkerType::WORKER &&
       worker->GetLanguage() == Language::PYTHON) {
     if (++first_job_registered_python_worker_count_ ==
         first_job_driver_wait_num_python_workers_) {
@@ -803,6 +803,7 @@ Status WorkerPool::RegisterDriver(const std::shared_ptr<WorkerInterface> &driver
     first_job_ = job_id;
     // If the number of Python workers we need to wait is positive.
     if (num_initial_python_workers_for_first_job_ > 0) {
+      delay_callback = true;
       PrestartDefaultCpuWorkers(Language::PYTHON,
                                 num_initial_python_workers_for_first_job_);
     }
