diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py
index a343c5384a3b..d0fc71b2fff4 100644
--- a/python/ray/data/_internal/execution/legacy_compat.py
+++ b/python/ray/data/_internal/execution/legacy_compat.py
@@ -14,7 +14,10 @@
 from ray.data.block import Block, BlockMetadata, List
 from ray.data.datasource import ReadTask
 from ray.data._internal.stats import StatsDict, DatastreamStats
-from ray.data._internal.stage_impl import RandomizeBlocksStage
+from ray.data._internal.stage_impl import (
+    RandomizeBlocksStage,
+    LimitStage,
+)
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.lazy_block_list import LazyBlockList
 from ray.data._internal.compute import (
@@ -26,6 +29,7 @@
 from ray.data._internal.memory_tracing import trace_allocation
 from ray.data._internal.plan import ExecutionPlan, OneToOneStage, AllToAllStage, Stage
 from ray.data._internal.execution.operators.map_operator import MapOperator
+from ray.data._internal.execution.operators.limit_operator import LimitOperator
 from ray.data._internal.execution.operators.all_to_all_operator import AllToAllOperator
 from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
 from ray.data._internal.execution.interfaces import (
@@ -300,6 +304,8 @@ def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
             min_rows_per_bundle=stage.target_block_size,
             ray_remote_args=stage.ray_remote_args,
         )
+    elif isinstance(stage, LimitStage):
+        return LimitOperator(stage.limit, input_op)
     elif isinstance(stage, AllToAllStage):
         fn = stage.fn
         block_udf = stage.block_udf
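
Note: the shim above maps the legacy `LimitStage` onto the new streaming `LimitOperator` (added in the next file). The `LimitStage` branch is deliberately placed before the `AllToAllStage` branch: `LimitStage` subclasses `AllToAllStage`, so the order of the `isinstance` checks is load-bearing. The operator's core job is per-bundle row accounting: pass whole blocks through until the next block would overshoot the limit, slice that block down to the remaining budget, and drop everything after it. A minimal, Ray-free sketch of that accounting (all names below are illustrative, not part of this PR):

    from typing import List, Tuple

    def plan_limit(rows_per_block: List[int], limit: int) -> List[Tuple[int, int]]:
        """Map per-block row counts to (block_index, rows_to_take) pairs."""
        taken = 0
        plan: List[Tuple[int, int]] = []
        for i, num_rows in enumerate(rows_per_block):
            if taken >= limit:
                break  # Limit reached: all remaining blocks are dropped.
            take = min(num_rows, limit - taken)  # Only the boundary block is sliced.
            plan.append((i, take))
            taken += take
        return plan

    # Three 4-row blocks with limit=10: two blocks pass through whole, and the
    # third is sliced down to its first 2 rows.
    assert plan_limit([4, 4, 4], 10) == [(0, 4), (1, 4), (2, 2)]
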
diff --git a/python/ray/data/_internal/execution/operators/limit_operator.py b/python/ray/data/_internal/execution/operators/limit_operator.py
new file mode 100644
index 000000000000..6a8da6e76993
--- /dev/null
+++ b/python/ray/data/_internal/execution/operators/limit_operator.py
@@ -0,0 +1,105 @@
+import ray
+import copy
+from collections import deque
+from ray.data.block import (
+    Block,
+    BlockAccessor,
+    BlockMetadata,
+)
+from ray.data._internal.stats import StatsDict
+from ray.data._internal.execution.interfaces import (
+    PhysicalOperator,
+    RefBundle,
+)
+from ray.data._internal.remote_fn import cached_remote_fn
+from ray.types import ObjectRef
+from typing import (
+    Deque,
+    List,
+    Optional,
+    Tuple,
+)
+
+
+class LimitOperator(PhysicalOperator):
+    """Physical operator for limit."""
+
+    def __init__(
+        self,
+        limit: int,
+        input_op: PhysicalOperator,
+    ):
+        self._limit = limit
+        self._consumed_rows = 0
+        self._cur_output_bundles = 0
+        self._buffer: Deque[RefBundle] = deque()
+        self._name = f"Limit[limit={limit}]"
+        self._output_metadata: List[BlockMetadata] = []
+        # Output bundles have at least one row each, so the limit also caps the bundle count.
+        self._num_outputs_total = input_op.num_outputs_total()
+        if self._num_outputs_total is not None:
+            self._num_outputs_total = min(self._num_outputs_total, limit)
+        super().__init__(self._name, [input_op])
+
+    def _limit_reached(self) -> bool:
+        return self._consumed_rows >= self._limit
+
+    def add_input(self, refs: RefBundle, input_index: int) -> None:
+        assert not self.completed()
+        assert input_index == 0, input_index
+        if self._limit_reached():
+            return
+        out_blocks: List[ObjectRef[Block]] = []
+        out_metadata: List[BlockMetadata] = []
+        for block, metadata in refs.blocks:
+            num_rows = metadata.num_rows
+            assert num_rows is not None
+            if self._consumed_rows + num_rows <= self._limit:
+                # The whole block fits under the limit: pass it through.
+                self._consumed_rows += num_rows
+                out_blocks.append(block)
+                out_metadata.append(metadata)
+                self._output_metadata.append(metadata)
+            else:
+                # Slice the last block down to the remaining row budget.
+                def slice_fn(block, metadata, num_rows) -> Tuple[Block, BlockMetadata]:
+                    block = BlockAccessor.for_block(block).slice(0, num_rows, copy=True)
+                    metadata = copy.deepcopy(metadata)
+                    metadata.num_rows = num_rows
+                    metadata.size_bytes = BlockAccessor.for_block(block).size_bytes()
+                    return block, metadata
+
+                block, metadata_ref = cached_remote_fn(slice_fn, num_returns=2).remote(
+                    block,
+                    metadata,
+                    self._limit - self._consumed_rows,
+                )
+                self._consumed_rows = self._limit
+                out_blocks.append(block)
+                # Blocks until the sliced block's metadata is computed.
+                metadata = ray.get(metadata_ref)
+                out_metadata.append(metadata)
+                self._output_metadata.append(metadata)
+                break
+        out_refs = RefBundle(
+            list(zip(out_blocks, out_metadata)),
+            owns_blocks=refs.owns_blocks,
+        )
+        self._buffer.append(out_refs)
+        self._cur_output_bundles += 1
+
+    def has_next(self) -> bool:
+        return len(self._buffer) > 0
+
+    def get_next(self) -> RefBundle:
+        return self._buffer.popleft()
+
+    def get_stats(self) -> StatsDict:
+        return {self._name: self._output_metadata}
+
+    def num_outputs_total(self) -> Optional[int]:
+        # Once the limit is reached, the emitted bundle count is exact.
+        if self._limit_reached():
+            return self._cur_output_bundles
+        else:
+            return self._num_outputs_total
diff --git a/python/ray/data/_internal/logical/operators/limit_operator.py b/python/ray/data/_internal/logical/operators/limit_operator.py
new file mode 100644
index 000000000000..c7d9690ad8b7
--- /dev/null
+++ b/python/ray/data/_internal/logical/operators/limit_operator.py
@@ -0,0 +1,16 @@
+from ray.data._internal.logical.interfaces import LogicalOperator
+
+
+class Limit(LogicalOperator):
+    """Logical operator for limit."""
+
+    def __init__(
+        self,
+        input_op: LogicalOperator,
+        limit: int,
+    ):
+        super().__init__(
+            "Limit",
+            [input_op],
+        )
+        self._limit = limit
diff --git a/python/ray/data/_internal/stage_impl.py b/python/ray/data/_internal/stage_impl.py
index 4a89454846c8..472853055301 100644
--- a/python/ray/data/_internal/stage_impl.py
+++ b/python/ray/data/_internal/stage_impl.py
@@ -8,7 +8,10 @@
     PushBasedShufflePartitionOp,
     SimpleShufflePartitionOp,
 )
-from ray.data._internal.split import _split_at_indices
+from ray.data._internal.split import (
+    _split_at_index,
+    _split_at_indices,
+)
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
 from ray.data._internal.execution.interfaces import TaskContext
@@ -344,3 +347,41 @@
         do_sort,
         sub_stage_names=["SortSample", "ShuffleMap", "ShuffleReduce"],
     )
+
+
+class LimitStage(AllToAllStage):
+    """Implementation of `Datastream.limit()`."""
+
+    def __init__(self, limit: int):
+        self._limit = limit
+        super().__init__(
+            "Limit",
+            None,
+            self._do_limit,
+        )
+
+    @property
+    def limit(self) -> int:
+        return self._limit
+
+    def _do_limit(
+        self,
+        input_block_list: BlockList,
+        clear_input_blocks: bool,
+        *_,
+    ):
+        if clear_input_blocks:
+            block_list = input_block_list.copy()
+            input_block_list.clear()
+        else:
+            block_list = input_block_list
+        block_list = block_list.truncate_by_rows(self._limit)
+        blocks, metadata, _, _ = _split_at_index(block_list, self._limit)
+        return (
+            BlockList(
+                blocks,
+                metadata,
+                owned_by_consumer=block_list._owned_by_consumer,
+            ),
+            {},
+        )
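
Note: `LimitStage` is the bulk-execution twin of the streaming operator above. `truncate_by_rows` keeps only the shortest prefix of blocks that covers `limit` rows (so blocks past the boundary are never fetched), and `_split_at_index` then cuts exactly at `limit`, slicing just the boundary block. This is the same arithmetic as the streaming sketch earlier, applied once over the whole block list; a Ray-free illustration (the helpers below are hypothetical stand-ins for the `BlockList`/split internals, assuming a non-empty block list):

    from typing import List, Tuple

    def truncate_by_rows(rows_per_block: List[int], limit: int) -> List[int]:
        """Keep the shortest prefix of blocks that covers `limit` rows."""
        total = 0
        prefix: List[int] = []
        for num_rows in rows_per_block:
            prefix.append(num_rows)
            total += num_rows
            if total >= limit:
                break
        return prefix

    def split_at_index(rows_per_block: List[int], limit: int) -> Tuple[List[int], int]:
        """Return the blocks kept whole and the rows kept from the boundary block."""
        prefix = truncate_by_rows(rows_per_block, limit)
        kept_from_last = min(limit - sum(prefix[:-1]), prefix[-1])
        return prefix[:-1], kept_from_last

    # Four 5-row blocks with limit=12: blocks 0 and 1 are kept whole, block 2 is
    # sliced to 2 rows, and block 3 is never touched.
    assert split_at_index([5, 5, 5, 5], 12) == ([5, 5], 2)
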
diff --git a/python/ray/data/datastream.py b/python/ray/data/datastream.py
index 2803c0b520ca..b3c6aa3c9df8 100644
--- a/python/ray/data/datastream.py
+++ b/python/ray/data/datastream.py
@@ -36,6 +36,7 @@
 )
 from ray.data._internal.logical.operators.n_ary_operator import Zip
 from ray.data._internal.logical.optimizers import LogicalPlan
+from ray.data._internal.logical.operators.limit_operator import Limit
 from ray.data._internal.logical.operators.map_operator import (
     Filter,
     FlatMap,
@@ -81,10 +82,11 @@
     RandomShuffleStage,
     ZipStage,
     SortStage,
+    LimitStage,
 )
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
-from ray.data._internal.split import _split_at_index, _split_at_indices, _get_num_rows
+from ray.data._internal.split import _split_at_indices, _get_num_rows
 from ray.data._internal.stats import DatastreamStats, DatastreamStatsSummary
 from ray.data.aggregate import AggregateFn, Max, Mean, Min, Std, Sum
 from ray.data.block import (
@@ -2197,40 +2199,12 @@ def limit(self, limit: int) -> "Datastream[T]":
         Returns:
             The truncated datastream.
         """
-        start_time = time.perf_counter()
-        # Truncate the block list to the minimum number of blocks that contains at least
-        # `limit` rows.
-        block_list = self._plan.execute().truncate_by_rows(limit)
-        blocks, metadata, _, _ = _split_at_index(block_list, limit)
-        split_duration = time.perf_counter() - start_time
-        meta_for_stats = [
-            BlockMetadata(
-                num_rows=m.num_rows,
-                size_bytes=m.size_bytes,
-                schema=m.schema,
-                input_files=m.input_files,
-                exec_stats=None,
-            )
-            for m in metadata
-        ]
-        datastream_stats = DatastreamStats(
-            stages={"Limit": meta_for_stats},
-            parent=self._plan.stats(),
-        )
-        datastream_stats.time_total_s = split_duration
-        return Datastream(
-            ExecutionPlan(
-                BlockList(
-                    blocks,
-                    metadata,
-                    owned_by_consumer=block_list._owned_by_consumer,
-                ),
-                datastream_stats,
-                run_by_consumer=block_list._owned_by_consumer,
-            ),
-            self._epoch,
-            self._lazy,
-        )
+        plan = self._plan.with_stage(LimitStage(limit))
+        logical_plan = self._logical_plan
+        if logical_plan is not None:
+            op = Limit(logical_plan.dag, limit=limit)
+            logical_plan = LogicalPlan(op)
+        return Datastream(plan, self._epoch, self._lazy, logical_plan)
 
     @ConsumptionAPI(pattern="Time complexity:")
     def take_batch(
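
Note: the net effect of this diff is that `Datastream.limit()` no longer triggers execution. It now appends a `LimitStage` to the plan (and, when a logical plan is attached, a `Limit` logical operator), so truncation runs at consumption time, and the streaming `LimitOperator` stops consuming input once `limit` rows have been taken. A usage sketch (the exact row format of `range()` follows this pre-strict-mode era of the API; treat the printed value as illustrative):

    import ray

    ds = ray.data.range(1000).limit(10)  # Lazy: this only builds the plan.
    rows = ds.take_all()                 # Execution happens here, on consumption.
    print(rows)                          # e.g. [0, 1, ..., 9] for this simple range.
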