[Data] Implement limit physical operator #34705

Merged · 12 commits · Apr 25, 2023
8 changes: 7 additions & 1 deletion python/ray/data/_internal/execution/legacy_compat.py
@@ -14,7 +14,10 @@
 from ray.data.block import Block, BlockMetadata, List
 from ray.data.datasource import ReadTask
 from ray.data._internal.stats import StatsDict, DatastreamStats
-from ray.data._internal.stage_impl import RandomizeBlocksStage
+from ray.data._internal.stage_impl import (
+    RandomizeBlocksStage,
+    LimitStage,
+)
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.lazy_block_list import LazyBlockList
 from ray.data._internal.compute import (
@@ -26,6 +29,7 @@
 from ray.data._internal.memory_tracing import trace_allocation
 from ray.data._internal.plan import ExecutionPlan, OneToOneStage, AllToAllStage, Stage
 from ray.data._internal.execution.operators.map_operator import MapOperator
+from ray.data._internal.execution.operators.limit_operator import LimitOperator
 from ray.data._internal.execution.operators.all_to_all_operator import AllToAllOperator
 from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
 from ray.data._internal.execution.interfaces import (
@@ -300,6 +304,8 @@ def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
             min_rows_per_bundle=stage.target_block_size,
             ray_remote_args=stage.ray_remote_args,
         )
+    elif isinstance(stage, LimitStage):
+        return LimitOperator(stage.limit, input_op)
     elif isinstance(stage, AllToAllStage):
         fn = stage.fn
         block_udf = stage.block_udf
98 changes: 98 additions & 0 deletions python/ray/data/_internal/execution/operators/limit_operator.py
@@ -0,0 +1,98 @@
import ray
import copy
from collections import deque
from ray.data.block import (
    Block,
    BlockAccessor,
    BlockMetadata,
)
from ray.data._internal.stats import StatsDict
from ray.data._internal.execution.interfaces import (
    PhysicalOperator,
    RefBundle,
)
from ray.data._internal.remote_fn import cached_remote_fn
from ray.types import ObjectRef
from typing import (
    Deque,
    List,
    Optional,
    Tuple,
)


class LimitOperator(PhysicalOperator):
    """Physical operator for limit."""

    def __init__(
        self,
        limit: int,
        input_op: PhysicalOperator,
    ):
        self._limit = limit
        self._consumed_rows = 0
        self._buffer: Deque[RefBundle] = deque()
        self._name = f"Limit[limit={limit}]"
        self._output_metadata: List[BlockMetadata] = []
        # The total output count is bounded by both the input operator's
        # output count and the limit itself.
        self._num_outputs_total = input_op.num_outputs_total()
        if self._num_outputs_total is not None:
            self._num_outputs_total = min(self._num_outputs_total, limit)
        super().__init__(self._name, [input_op])

    def _limit_reached(self) -> bool:
        return self._consumed_rows >= self._limit

    def add_input(self, refs: RefBundle, input_index: int) -> None:
        assert not self.completed()
        assert input_index == 0, input_index
        if self._limit_reached():
            # Drop any input that arrives after the limit has been satisfied.
            return
        out_blocks: List[ObjectRef[Block]] = []
        out_metadata: List[BlockMetadata] = []
        for block, metadata in refs.blocks:
            num_rows = metadata.num_rows
            assert num_rows is not None
            if self._consumed_rows + num_rows <= self._limit:
                # The whole block fits under the limit: pass it through as is.
                self._consumed_rows += num_rows
                out_blocks.append(block)
                out_metadata.append(metadata)
                self._output_metadata.append(metadata)
            else:
                # Slice the last block.
                def slice_fn(block, metadata, num_rows) -> Tuple[Block, BlockMetadata]:
                    block = BlockAccessor.for_block(block).slice(0, num_rows, copy=True)
                    metadata = copy.deepcopy(metadata)
                    metadata.num_rows = num_rows
                    metadata.size_bytes = BlockAccessor.for_block(block).size_bytes()
                    return block, metadata

                block, metadata_ref = cached_remote_fn(slice_fn, num_returns=2).remote(
                    block,
                    metadata,
                    self._limit - self._consumed_rows,
                )
                # Mark the limit as fully consumed so later inputs are dropped.
                self._consumed_rows = self._limit
                out_blocks.append(block)
                # Fetch the sliced block's metadata (blocks on the slice task).
                metadata = ray.get(metadata_ref)
                out_metadata.append(metadata)
                self._output_metadata.append(metadata)
                break
        out_refs = RefBundle(
            list(zip(out_blocks, out_metadata)),
            owns_blocks=refs.owns_blocks,
        )
        self._buffer.append(out_refs)

    def has_next(self) -> bool:
        return len(self._buffer) > 0

    def get_next(self) -> RefBundle:
        return self._buffer.popleft()

    def get_stats(self) -> StatsDict:
        return {self._name: self._output_metadata}
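
A note on get_stats, which the review thread below also touches on: a StatsDict maps a stage name to the list of per-block metadata that stage produced, so the operator above should report something shaped like this minimal sketch (all field values invented for illustration):

# Illustrative shape of the StatsDict returned by get_stats(), assuming the
# operator emitted two output blocks. All field values here are invented.
{
    "Limit[limit=5]": [
        BlockMetadata(num_rows=3, size_bytes=24, schema=None, input_files=None, exec_stats=None),
        BlockMetadata(num_rows=2, size_bytes=16, schema=None, input_files=None, exec_stats=None),
    ]
}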
Contributor Author: Not sure if this is the correct way to handle stats; I just followed map_operator. Are there any docs?

Contributor: Seems reasonable; we should probably improve the docstring in interfaces.py.

    def num_outputs_total(self) -> Optional[int]:
        if self._limit_reached():
            return self._limit
        else:
            return self._num_outputs_total
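
To see the row accounting in add_input concretely, here is a minimal, Ray-free sketch of the same pass-through-or-slice logic. The helper name apply_limit and the (block_id, num_rows) representation are invented for illustration:

from typing import List, Tuple

def apply_limit(blocks: List[Tuple[str, int]], limit: int) -> List[Tuple[str, int]]:
    # Whole blocks pass through until the next one would overshoot the
    # limit; that block is truncated and everything after it is dropped.
    out: List[Tuple[str, int]] = []
    consumed = 0
    for block_id, num_rows in blocks:
        if consumed + num_rows <= limit:
            consumed += num_rows
            out.append((block_id, num_rows))
        else:
            out.append((block_id, limit - consumed))
            consumed = limit
            break
    return out

assert apply_limit([("a", 3), ("b", 3), ("c", 3)], limit=5) == [("a", 3), ("b", 2)]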
16 changes: 16 additions & 0 deletions python/ray/data/_internal/logical/operators/limit_operator.py
@@ -0,0 +1,16 @@
from ray.data._internal.logical.interfaces import LogicalOperator


class Limit(LogicalOperator):
    """Logical operator for limit."""

    def __init__(
        self,
        input_op: LogicalOperator,
        limit: int,
    ):
        super().__init__(
            "Limit",
            [input_op],
        )
        self._limit = limit
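
This PR routes limit through the legacy stage path (see legacy_compat.py above); under the logical-plan path one would expect a planner rule that turns a Limit node into a LimitOperator. A hypothetical sketch of such a rule, where the function name and the way the upstream physical operator is supplied are assumptions rather than part of this diff:

# Hypothetical sketch: a planner rule mapping the logical Limit node to the
# physical LimitOperator. `_plan_limit_op` and `input_physical_op` are
# invented names, not part of this PR.
def _plan_limit_op(op: Limit, input_physical_op: PhysicalOperator) -> PhysicalOperator:
    return LimitOperator(op._limit, input_physical_op)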
43 changes: 42 additions & 1 deletion python/ray/data/_internal/stage_impl.py
@@ -8,7 +8,10 @@
     PushBasedShufflePartitionOp,
     SimpleShufflePartitionOp,
 )
-from ray.data._internal.split import _split_at_indices
+from ray.data._internal.split import (
+    _split_at_index,
+    _split_at_indices,
+)
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
 from ray.data._internal.execution.interfaces import TaskContext
@@ -344,3 +347,41 @@ def do_sort(
            do_sort,
            sub_stage_names=["SortSample", "ShuffleMap", "ShuffleReduce"],
        )


class LimitStage(AllToAllStage):
    """Implementation of `Datastream.limit()`."""

    def __init__(self, limit: int):
        self._limit = limit
        super().__init__(
            "Limit",
            None,
            self._do_limit,
        )

    @property
    def limit(self) -> int:
        return self._limit

    def _do_limit(
        self,
        input_block_list: BlockList,
        clear_input_blocks: bool,
        *_,
    ):
        if clear_input_blocks:
            block_list = input_block_list.copy()
            input_block_list.clear()
        else:
            block_list = input_block_list
        # Truncate to the fewest blocks that still contain at least `limit`
        # rows, then split off exactly `limit` rows.
        block_list = block_list.truncate_by_rows(self._limit)
        blocks, metadata, _, _ = _split_at_index(block_list, self._limit)
        return (
            BlockList(
                blocks,
                metadata,
                owned_by_consumer=block_list._owned_by_consumer,
            ),
            {},
        )
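
The truncate-then-split behavior is easiest to see on plain lists of rows. A minimal sketch in which truncate_by_rows and split_at_index are simplified stand-ins for the BlockList method and _split_at_index:

from typing import List

def truncate_by_rows(blocks: List[List[int]], limit: int) -> List[List[int]]:
    # Keep the fewest leading blocks that together hold at least `limit` rows.
    out: List[List[int]] = []
    total = 0
    for block in blocks:
        out.append(block)
        total += len(block)
        if total >= limit:
            break
    return out

def split_at_index(blocks: List[List[int]], index: int) -> List[List[int]]:
    # Return the blocks holding the first `index` rows, slicing the last one.
    out: List[List[int]] = []
    remaining = index
    for block in blocks:
        if remaining <= 0:
            break
        out.append(block[:remaining])
        remaining -= len(block)
    return out

blocks = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
assert truncate_by_rows(blocks, 5) == [[0, 1, 2], [3, 4, 5]]
assert split_at_index(truncate_by_rows(blocks, 5), 5) == [[0, 1, 2], [3, 4]]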
44 changes: 9 additions & 35 deletions python/ray/data/datastream.py
@@ -36,6 +36,7 @@
 )
 from ray.data._internal.logical.operators.n_ary_operator import Zip
 from ray.data._internal.logical.optimizers import LogicalPlan
+from ray.data._internal.logical.operators.limit_operator import Limit
 from ray.data._internal.logical.operators.map_operator import (
     Filter,
     FlatMap,
@@ -81,10 +82,11 @@
     RandomShuffleStage,
     ZipStage,
     SortStage,
+    LimitStage,
 )
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
-from ray.data._internal.split import _split_at_index, _split_at_indices, _get_num_rows
+from ray.data._internal.split import _split_at_indices, _get_num_rows
 from ray.data._internal.stats import DatastreamStats, DatastreamStatsSummary
 from ray.data.aggregate import AggregateFn, Max, Mean, Min, Std, Sum
 from ray.data.block import (
@@ -2197,40 +2199,12 @@ def limit(self, limit: int) -> "Datastream[T]":
         Returns:
             The truncated datastream.
         """
-        start_time = time.perf_counter()
-        # Truncate the block list to the minimum number of blocks that contains at least
-        # `limit` rows.
-        block_list = self._plan.execute().truncate_by_rows(limit)
-        blocks, metadata, _, _ = _split_at_index(block_list, limit)
-        split_duration = time.perf_counter() - start_time
-        meta_for_stats = [
-            BlockMetadata(
-                num_rows=m.num_rows,
-                size_bytes=m.size_bytes,
-                schema=m.schema,
-                input_files=m.input_files,
-                exec_stats=None,
-            )
-            for m in metadata
-        ]
-        datastream_stats = DatastreamStats(
-            stages={"Limit": meta_for_stats},
-            parent=self._plan.stats(),
-        )
-        datastream_stats.time_total_s = split_duration
-        return Datastream(
-            ExecutionPlan(
-                BlockList(
-                    blocks,
-                    metadata,
-                    owned_by_consumer=block_list._owned_by_consumer,
-                ),
-                datastream_stats,
-                run_by_consumer=block_list._owned_by_consumer,
-            ),
-            self._epoch,
-            self._lazy,
-        )
+        plan = self._plan.with_stage(LimitStage(limit))
+        logical_plan = self._logical_plan
+        if logical_plan is not None:
+            op = Limit(logical_plan.dag, limit=limit)
+            logical_plan = LogicalPlan(op)
+        return Datastream(plan, self._epoch, self._lazy, logical_plan)
Contributor: Nice.

     @ConsumptionAPI(pattern="Time complexity:")
     def take_batch(
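
End to end, limit() is now lazy: it appends a LimitStage to the execution plan (and a Limit node to the logical plan, when one exists) instead of eagerly executing the datastream and splitting its blocks. A quick usage sketch, assuming the Ray 2.x ray.data.range API that yields integer rows:

import ray

# limit() only appends a stage, so nothing executes until the
# datastream is consumed (e.g., by take()).
ds = ray.data.range(1000).limit(5)
print(ds.take(5))  # -> [0, 1, 2, 3, 4]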