Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Optimize block prefetching #35568

Merged
merged 15 commits into from
Jun 1, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,4 @@ def _prefetch_blocks(
trace_deallocation(
block_ref, "block_batching._prefetch_blocks", free=eager_free
)
prefetcher.stop()
15 changes: 11 additions & 4 deletions python/ray/data/_internal/block_batching/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import abc

from dataclasses import dataclass
from typing import Any
from typing import Any, List

from ray.types import ObjectRef
from ray.data.block import Block, DataBatch
Expand Down Expand Up @@ -33,9 +35,14 @@ class CollatedBatch(Batch):
data: Any


class BlockPrefetcher:
class BlockPrefetcher(metaclass=abc.ABCMeta):
"""Interface for prefetching blocks."""

def prefetch_blocks(self, blocks: ObjectRef[Block]):
@abc.abstractmethod
def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
"""Prefetch the provided blocks to this node."""
raise NotImplementedError
pass

def stop(self):
"""Stop prefetching and release resources."""
pass
1 change: 1 addition & 0 deletions python/ray/data/_internal/block_batching/iter_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def prefetch_batches_locally(
pass
yield block_ref
trace_deallocation(block_ref, loc="iter_batches", free=eager_free)
prefetcher.stop()


def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]:
Expand Down
51 changes: 45 additions & 6 deletions python/ray/data/_internal/block_batching/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,52 @@ def execute_computation(thread_index: int):
class WaitBlockPrefetcher(BlockPrefetcher):
"""Block prefetcher using ray.wait."""

def prefetch_blocks(self, blocks: ObjectRef[Block]):
ray.wait(blocks, num_returns=1, fetch_local=True)
def __init__(self):
self._blocks = []
self._stopped = False
self._condition = threading.Condition()
self._thread = threading.Thread(
target=self._run,
name="Prefetcher",
daemon=True,
)
self._thread.start()

def _run(self):
while True:
try:
blocks_to_wait = []
with self._condition:
if len(self._blocks) > 0:
blocks_to_wait, self._blocks = self._blocks[:], []
else:
if self._stopped:
return
blocks_to_wait = []
self._condition.wait()
if len(blocks_to_wait) > 0:
ray.wait(blocks_to_wait, num_returns=1, fetch_local=True)
except Exception:
logger.exception("Error in prefetcher thread.")

def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
with self._condition:
if self._stopped:
raise RuntimeError("Prefetcher is stopped.")
self._blocks = blocks
self._condition.notify()

def stop(self):
with self._condition:
if self._stopped:
return
self._stopped = True
self._condition.notify()

def __del__(self):
self.stop()


# ray.wait doesn't work as expected, so we have an
# actor-based prefetcher as a work around. See
# https://github.com/ray-project/ray/issues/23983 for details.
class ActorBlockPrefetcher(BlockPrefetcher):
"""Block prefetcher using a local actor."""

Expand All @@ -308,7 +347,7 @@ def _get_or_create_actor_prefetcher() -> "ActorHandle":
get_if_exists=True,
).remote()

def prefetch_blocks(self, blocks: ObjectRef[Block]):
def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
self.prefetch_actor.prefetch.remote(*blocks)


Expand Down