From 0052c174a212fda1b4843882a769d3f3ab9f0b43 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Tue, 14 Nov 2023 01:20:53 +0530 Subject: [PATCH] [BUG] clean up ray scheduler threads after computing partial results (#1597) * adds a `stop_plan` method on the Ray Scheduler which is called when the `RayRunner.run_iter` generator goes out of scope. * This fixes the issue where we have a hanging thread at the end of execution * closes: https://github.com/Eventual-Inc/Daft/issues/1591 --- daft/runners/ray_runner.py | 69 ++++++++++++++++++++++++++------------ tests/ray/runner.py | 48 ++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 21 deletions(-) create mode 100644 tests/ray/runner.py diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 3a8038ce1e..ba2b4fb7c7 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -400,6 +400,7 @@ def __init__(self, max_task_backlog: int | None) -> None: self.threads_by_df: dict[str, threading.Thread] = dict() self.results_by_df: dict[str, Queue] = {} + self.active_by_df: dict[str, bool] = dict() def next(self, result_uuid: str) -> ray.ObjectRef | StopIteration: # Case: thread is terminated and no longer exists. @@ -407,17 +408,16 @@ def next(self, result_uuid: str) -> ray.ObjectRef | StopIteration: if result_uuid not in self.threads_by_df: return StopIteration() + # Case: thread needs to be terminated + if not self.active_by_df.get(result_uuid, False): + return StopIteration() + # Common case: get the next result from the thread. result = self.results_by_df[result_uuid].get() - # If there are no more results, delete the thread. 
- if isinstance(result, (StopIteration, Exception)): - self.threads_by_df[result_uuid].join() - del self.threads_by_df[result_uuid] - return result - def run_plan( + def start_plan( self, plan_scheduler: PhysicalPlanScheduler, psets: dict[str, ray.ObjectRef], @@ -425,6 +425,7 @@ def run_plan( results_buffer_size: int | None = None, ) -> None: self.results_by_df[result_uuid] = Queue(maxsize=results_buffer_size or -1) + self.active_by_df[result_uuid] = True t = threading.Thread( target=self._run_plan, @@ -438,6 +439,20 @@ def run_plan( t.start() self.threads_by_df[result_uuid] = t + def active_plans(self) -> list[str]: + return [r_uuid for r_uuid, is_active in self.active_by_df.items() if is_active] + + def stop_plan(self, result_uuid: str) -> None: + if result_uuid in self.active_by_df: + # Mark df as non-active + self.active_by_df[result_uuid] = False + # wait till thread gracefully completes + self.threads_by_df[result_uuid].join() + # remove thread and history of df + del self.threads_by_df[result_uuid] + del self.active_by_df[result_uuid] + del self.results_by_df[result_uuid] + def _run_plan( self, plan_scheduler: PhysicalPlanScheduler, @@ -461,8 +476,8 @@ def _run_plan( try: next_step = next(tasks) - while True: # Loop: Dispatch -> await. - while True: # Loop: Dispatch (get tasks -> batch dispatch). + while self.active_by_df.get(result_uuid, False): # Loop: Dispatch -> await. + while self.active_by_df.get(result_uuid, False): # Loop: Dispatch (get tasks -> batch dispatch). 
tasks_to_dispatch: list[PartitionTask] = [] cores: int = next(num_cpus_provider) - self.reserved_cores @@ -619,6 +634,12 @@ def __init__( max_task_backlog=max_task_backlog, ) + def active_plans(self) -> list[str]: + if isinstance(self.ray_context, ray.client_builder.ClientContext): + return ray.get(self.scheduler_actor.active_plans.remote()) + else: + return self.scheduler.active_plans() + def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[ray.ObjectRef]: # Optimize the logical plan. builder = builder.optimize() @@ -634,7 +655,7 @@ def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None result_uuid = str(uuid.uuid4()) if isinstance(self.ray_context, ray.client_builder.ClientContext): ray.get( - self.scheduler_actor.run_plan.remote( + self.scheduler_actor.start_plan.remote( plan_scheduler=plan_scheduler, psets=psets, result_uuid=result_uuid, @@ -643,25 +664,31 @@ def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None ) else: - self.scheduler.run_plan( + self.scheduler.start_plan( plan_scheduler=plan_scheduler, psets=psets, result_uuid=result_uuid, results_buffer_size=results_buffer_size, ) - - while True: + try: + while True: + if isinstance(self.ray_context, ray.client_builder.ClientContext): + result = ray.get(self.scheduler_actor.next.remote(result_uuid)) + else: + result = self.scheduler.next(result_uuid) + + if isinstance(result, StopIteration): + break + elif isinstance(result, Exception): + raise result + + yield result + finally: + # Generator is out of scope, ensure that state has been cleaned up if isinstance(self.ray_context, ray.client_builder.ClientContext): - result = ray.get(self.scheduler_actor.next.remote(result_uuid)) + ray.get(self.scheduler_actor.stop_plan.remote(result_uuid)) else: - result = self.scheduler.next(result_uuid) - - if isinstance(result, StopIteration): - return - elif isinstance(result, Exception): - raise result - - yield result + 
self.scheduler.stop_plan(result_uuid) def run_iter_tables(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[Table]: for ref in self.run_iter(builder, results_buffer_size=results_buffer_size): diff --git a/tests/ray/runner.py b/tests/ray/runner.py new file mode 100644 index 0000000000..16026db235 --- /dev/null +++ b/tests/ray/runner.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import pytest + +import daft +from daft.context import get_context + + +@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner") +def test_active_plan_clean_up_df_show(): + path = "tests/assets/parquet-data/mvp.parquet" + df = daft.read_parquet([path, path]) + df.show() + runner = get_context().runner() + assert len(runner.active_plans()) == 0 + + +@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner") +def test_active_plan_single_iter_partitions(): + path = "tests/assets/parquet-data/mvp.parquet" + df = daft.read_parquet([path, path]) + iter = df.iter_partitions() + next(iter) + runner = get_context().runner() + assert len(runner.active_plans()) == 1 + del iter + assert len(runner.active_plans()) == 0 + + +@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner") +def test_active_plan_multiple_iter_partitions(): + path = "tests/assets/parquet-data/mvp.parquet" + df = daft.read_parquet([path, path]) + iter = df.iter_partitions() + next(iter) + runner = get_context().runner() + assert len(runner.active_plans()) == 1 + + df2 = daft.read_parquet([path, path]) + iter2 = df2.iter_partitions() + next(iter2) + assert len(runner.active_plans()) == 2 + + del iter + assert len(runner.active_plans()) == 1 + + del iter2 + assert len(runner.active_plans()) == 0