Skip to content

Commit

Permalink
[BUG] clean up ray scheduler threads after computing partial results (#…
Browse files Browse the repository at this point in the history
…1597)

* adds `stop_plan` method on the Ray Scheduler when is called when the
`RayRunner.run_iter` goes out of scope.
* This fixes the issues where we have a hanging thread at the end of
execution
* closes: #1591
  • Loading branch information
samster25 authored Nov 13, 2023
1 parent 819dcd8 commit 0052c17
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 21 deletions.
69 changes: 48 additions & 21 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,31 +400,32 @@ def __init__(self, max_task_backlog: int | None) -> None:

self.threads_by_df: dict[str, threading.Thread] = dict()
self.results_by_df: dict[str, Queue] = {}
self.active_by_df: dict[str, bool] = dict()

def next(self, result_uuid: str) -> ray.ObjectRef | StopIteration:
# Case: thread is terminated and no longer exists.
# Should only be hit for repeated calls to next() after StopIteration.
if result_uuid not in self.threads_by_df:
return StopIteration()

# Case: thread needs to be terminated
if not self.active_by_df.get(result_uuid, False):
return StopIteration()

# Common case: get the next result from the thread.
result = self.results_by_df[result_uuid].get()

# If there are no more results, delete the thread.
if isinstance(result, (StopIteration, Exception)):
self.threads_by_df[result_uuid].join()
del self.threads_by_df[result_uuid]

return result

def run_plan(
def start_plan(
self,
plan_scheduler: PhysicalPlanScheduler,
psets: dict[str, ray.ObjectRef],
result_uuid: str,
results_buffer_size: int | None = None,
) -> None:
self.results_by_df[result_uuid] = Queue(maxsize=results_buffer_size or -1)
self.active_by_df[result_uuid] = True

t = threading.Thread(
target=self._run_plan,
Expand All @@ -438,6 +439,20 @@ def run_plan(
t.start()
self.threads_by_df[result_uuid] = t

def active_plans(self) -> list[str]:
return [r_uuid for r_uuid, is_active in self.active_by_df.items() if is_active]

def stop_plan(self, result_uuid: str) -> None:
if result_uuid in self.active_by_df:
# Mark df as non-active
self.active_by_df[result_uuid] = False
# wait till thread gracefully completes
self.threads_by_df[result_uuid].join()
# remove thread and history of df
del self.threads_by_df[result_uuid]
del self.active_by_df[result_uuid]
del self.results_by_df[result_uuid]

def _run_plan(
self,
plan_scheduler: PhysicalPlanScheduler,
Expand All @@ -461,8 +476,8 @@ def _run_plan(
try:
next_step = next(tasks)

while True: # Loop: Dispatch -> await.
while True: # Loop: Dispatch (get tasks -> batch dispatch).
while self.active_by_df.get(result_uuid, False): # Loop: Dispatch -> await.
while self.active_by_df.get(result_uuid, False): # Loop: Dispatch (get tasks -> batch dispatch).
tasks_to_dispatch: list[PartitionTask] = []

cores: int = next(num_cpus_provider) - self.reserved_cores
Expand Down Expand Up @@ -619,6 +634,12 @@ def __init__(
max_task_backlog=max_task_backlog,
)

def active_plans(self) -> list[str]:
if isinstance(self.ray_context, ray.client_builder.ClientContext):
return ray.get(self.scheduler_actor.active_plans.remote())
else:
return self.scheduler.active_plans()

def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[ray.ObjectRef]:
# Optimize the logical plan.
builder = builder.optimize()
Expand All @@ -634,7 +655,7 @@ def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None
result_uuid = str(uuid.uuid4())
if isinstance(self.ray_context, ray.client_builder.ClientContext):
ray.get(
self.scheduler_actor.run_plan.remote(
self.scheduler_actor.start_plan.remote(
plan_scheduler=plan_scheduler,
psets=psets,
result_uuid=result_uuid,
Expand All @@ -643,25 +664,31 @@ def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None
)

else:
self.scheduler.run_plan(
self.scheduler.start_plan(
plan_scheduler=plan_scheduler,
psets=psets,
result_uuid=result_uuid,
results_buffer_size=results_buffer_size,
)

while True:
try:
while True:
if isinstance(self.ray_context, ray.client_builder.ClientContext):
result = ray.get(self.scheduler_actor.next.remote(result_uuid))
else:
result = self.scheduler.next(result_uuid)

if isinstance(result, StopIteration):
break
elif isinstance(result, Exception):
raise result

yield result
finally:
# Generator is out of scope, ensure that state has been cleaned up
if isinstance(self.ray_context, ray.client_builder.ClientContext):
result = ray.get(self.scheduler_actor.next.remote(result_uuid))
ray.get(self.scheduler_actor.stop_plan.remote(result_uuid))
else:
result = self.scheduler.next(result_uuid)

if isinstance(result, StopIteration):
return
elif isinstance(result, Exception):
raise result

yield result
self.scheduler.stop_plan(result_uuid)

def run_iter_tables(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[Table]:
for ref in self.run_iter(builder, results_buffer_size=results_buffer_size):
Expand Down
48 changes: 48 additions & 0 deletions tests/ray/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

import pytest

import daft
from daft.context import get_context


@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
def test_active_plan_clean_up_df_show():
path = "tests/assets/parquet-data/mvp.parquet"
df = daft.read_parquet([path, path])
df.show()
runner = get_context().runner()
assert len(runner.active_plans()) == 0


@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
def test_active_plan_single_iter_partitions():
path = "tests/assets/parquet-data/mvp.parquet"
df = daft.read_parquet([path, path])
iter = df.iter_partitions()
next(iter)
runner = get_context().runner()
assert len(runner.active_plans()) == 1
del iter
assert len(runner.active_plans()) == 0


@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
def test_active_plan_multiple_iter_partitions():
path = "tests/assets/parquet-data/mvp.parquet"
df = daft.read_parquet([path, path])
iter = df.iter_partitions()
next(iter)
runner = get_context().runner()
assert len(runner.active_plans()) == 1

df2 = daft.read_parquet([path, path])
iter2 = df2.iter_partitions()
next(iter2)
assert len(runner.active_plans()) == 2

del iter
assert len(runner.active_plans()) == 1

del iter2
assert len(runner.active_plans()) == 0

0 comments on commit 0052c17

Please sign in to comment.