[Data][Split optimization] don't generate empty blocks #26768
@@ -1,10 +1,12 @@
+import itertools
 import logging
 from typing import Iterable, Tuple, List

 import ray
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data.block import (
     Block,
+    BlockPartition,
     BlockAccessor,
     BlockExecStats,
     BlockMetadata,
@@ -15,7 +17,7 @@


 def _calculate_blocks_rows(
-    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
+    blocks_with_metadata: BlockPartition,
 ) -> List[int]:
     """Calculate the number of rows for a list of blocks with metadata."""
     get_num_rows = cached_remote_fn(_get_num_rows)
@@ -24,6 +26,7 @@ def _calculate_blocks_rows(
         if metadata.num_rows is None:
             # Need to fetch number of rows.
             num_rows = ray.get(get_num_rows.remote(block))
+            metadata.num_rows = num_rows
         else:
             num_rows = metadata.num_rows
         block_rows.append(num_rows)
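(Writing the fetched row count back into the metadata matters later: the _split_all_blocks changes below read meta.num_rows directly instead of taking a separate block_rows list, which only works because phase 1 now fills it in.)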
@@ -88,16 +91,15 @@ def _split_single_block(
     block_id: int,
     block: Block,
     meta: BlockMetadata,
-    block_row: int,
     split_indices: List[int],
-) -> Tuple[int, List[Tuple[ObjectRef[Block], BlockMetadata]]]:
+) -> Tuple[int, BlockPartition]:
     """Split the provided block at the given indices."""
     split_result = []
     block_accessor = BlockAccessor.for_block(block)
     prev_index = 0
     # append one more entry at the last so we don't
     # need handle empty edge case.
-    split_indices.append(block_row)
+    split_indices.append(meta.num_rows)
     for index in split_indices:
         logger.debug(f"slicing block {prev_index}:{index}")
         stats = BlockExecStats.builder()
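For intuition, the sentinel append means every block yields len(split_indices) + 1 slices, including zero-row slices for boundary or duplicate indices, which are exactly the empty blocks this PR eliminates. A plain-list sketch of the same loop (illustrative only; the real code slices through the BlockAccessor and tracks exec stats):

from typing import List

def split_list_at_indices(block: List[int], split_indices: List[int]) -> List[List[int]]:
    # Mirror the loop in _split_single_block: append the row count as a
    # sentinel so the final slice (prev_index:num_rows) needs no special case.
    indices = split_indices + [len(block)]
    slices, prev_index = [], 0
    for index in indices:
        slices.append(block[prev_index:index])
        prev_index = index
    return slices

# Boundary and duplicate indices produce empty slices, as in the new
# test_split_single_block case for [0, 1, 1, 3]:
assert split_list_at_indices([1, 2, 3], [0, 1, 1, 3]) == [[], [1], [], [2, 3], []]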
@@ -115,23 +117,38 @@ def _split_single_block(
     return (block_id, split_result)


+def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
+    """drop split indices that creates empty block split. This could happen when there
+    are duplicated indices, or index equal to 0 (start of the block) or num_block_rows
+    (end of the block).
+    """
+    prev_index = -1
+    optimized_indices = []
+    for index in block_split_indices:
+        if index == 0 or index == num_rows:
+            continue
+        if index == prev_index:
+            continue
+        optimized_indices.append(index)
+        prev_index = index
+    return optimized_indices
+
+
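Since this helper is pure list manipulation, its effect is easy to check in isolation; these expectations match the new test_drop_empty_block_split unit test further down (import path as used by the new tests):

from ray.data._internal.split import _drop_empty_block_split

# Indices at the block boundaries (0 and num_rows) and duplicated indices
# would each produce a zero-row slice, so they are filtered out up front.
assert _drop_empty_block_split([0, 1, 2, 3], 3) == [1, 2]
assert _drop_empty_block_split([1, 1, 2, 2], 3) == [1, 2]
assert _drop_empty_block_split([0], 0) == []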
 def _split_all_blocks(
-    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
-    block_rows: List[int],
+    blocks_with_metadata: BlockPartition,
     per_block_split_indices: List[List[int]],
-) -> List[Tuple[ObjectRef[Block], BlockMetadata]]:
+) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
     """Split all the input blocks based on the split indices"""
     split_single_block = cached_remote_fn(_split_single_block)

-    all_blocks_split_results: List[List[Tuple[ObjectRef[Block], BlockMetadata]]] = [
-        None
-    ] * len(blocks_with_metadata)
+    all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)

     split_single_block_futures = []

     for block_id, block_split_indices in enumerate(per_block_split_indices):
         (block_ref, meta) = blocks_with_metadata[block_id]
-        block_row = block_rows[block_id]
+        block_row = meta.num_rows
+        block_split_indices = _drop_empty_block_split(block_split_indices, block_row)
         if len(block_split_indices) == 0:
             # optimization: if no split is needed, we just need to add it to the
             # result
@@ -143,46 +160,44 @@ def _split_all_blocks(
                     block_id,
                     block_ref,
                     meta,
-                    block_row,
                     block_split_indices,
                 )
             )
     if split_single_block_futures:
         split_single_block_results = ray.get(split_single_block_futures)
         for block_id, block_split_result in split_single_block_results:
             all_blocks_split_results[block_id] = block_split_result
-    return all_blocks_split_results
+    return itertools.chain.from_iterable(all_blocks_split_results)


 def _generate_global_split_results(
-    all_blocks_split_results: List[List[Tuple[ObjectRef[Block], BlockMetadata]]],
+    all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]],
+    global_split_sizes: List[int],
 ) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
     """Reassemble per block's split result into final split result."""
     result_blocks = []
     result_metas = []

     current_blocks = []
     current_meta = []
-
-    if len(all_blocks_split_results) == 0:
-        return ([], [])
-
-    for single_block_split_result in all_blocks_split_results:
-        assert len(single_block_split_result) > 0
-        for i, (block, meta) in enumerate(single_block_split_result):
-            # we should create a new global split whenever
-            # we encountered a new local split in the per block
-            # split result.
-            if i != 0:
-                result_blocks.append(current_blocks)
-                result_metas.append(current_meta)
-                current_blocks = []
-                current_meta = []
-            current_blocks.append(block)
-
-    assert len(current_blocks) > 0
-    result_blocks.append(current_blocks)
-    result_metas.append(current_meta)
+    current_split_size = 0
+    current_split_id = 0
+
+    while current_split_id < len(global_split_sizes):
+        if current_split_size >= global_split_sizes[current_split_id]:
+            assert current_split_size == global_split_sizes[current_split_id]
+            result_blocks.append(current_blocks)
+            result_metas.append(current_meta)
+
+            current_blocks = []
+            current_meta = []
+            current_split_size = 0
+            current_split_id += 1
+        else:
+            (block_ref, meta) = next(all_blocks_split_results)
+            current_blocks.append(block_ref)
+            current_meta.append(meta)
+            current_split_size += meta.num_rows

     return result_blocks, result_metas
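The rewritten assembly loop consumes a flat iterator of (block, metadata) pairs and groups them until each split's expected row count is met; a zero-sized split closes immediately as an empty group, so empty splits survive without any empty blocks. A self-contained sketch of the same grouping idea, simplified so plain (name, num_rows) pairs stand in for (ObjectRef[Block], BlockMetadata):

from typing import Iterator, List, Tuple

def group_by_sizes(
    blocks: Iterator[Tuple[str, int]], split_sizes: List[int]
) -> List[List[str]]:
    # Same grouping idea as _generate_global_split_results: pull blocks off
    # the flat stream until the current split's row budget is exactly met.
    result: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for size in split_sizes:
        while current_size < size:
            name, num_rows = next(blocks)
            current.append(name)
            current_size += num_rows
        assert current_size == size  # splits always land on block boundaries
        result.append(current)
        current, current_size = [], 0
    return result

# Blocks of 1, 2, and 1 rows; sizes [3, 0, 1] give an empty middle split
# with no empty blocks (mirrors the new test_generate_global_split_results).
assert group_by_sizes(iter([("a", 1), ("b", 2), ("c", 1)]), [3, 0, 1]) == [
    ["a", "b"],
    [],
    ["c"],
]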
@@ -205,20 +220,25 @@ def _split_at_indices(
     # phase 1: calculate the per block split indices.
     blocks_with_metadata = list(blocks_with_metadata)
     if len(blocks_with_metadata) == 0:
-        return ([] * (len(indices) + 1), [] * (len(indices) + 1))
+        return ([[]] * (len(indices) + 1), [[]] * (len(indices) + 1))
     block_rows: List[int] = _calculate_blocks_rows(blocks_with_metadata)
     valid_indices = _generate_valid_indices(block_rows, indices)
     per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
         block_rows, valid_indices
     )

     # phase 2: split each block based on the indices from previous step.
-    all_blocks_split_results: List[
-        List[Tuple[ObjectRef[Block], BlockMetadata]]
-    ] = _split_all_blocks(blocks_with_metadata, block_rows, per_block_split_indices)
+    all_blocks_split_results: Iterable[
+        Tuple[ObjectRef[Block], BlockMetadata]
+    ] = _split_all_blocks(blocks_with_metadata, per_block_split_indices)

     # phase 3: generate the final split.
-    return _generate_global_split_results(all_blocks_split_results)
+
+    # first calculate the size for each split.
+    helper = [0] + valid_indices + [sum(block_rows)]
+    split_sizes = [helper[i] - helper[i - 1] for i in range(1, len(helper))]
+
+    return _generate_global_split_results(all_blocks_split_results, split_sizes)


 def _get_num_rows(block: Block) -> int:
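The split sizes are just pairwise differences over the valid indices padded with 0 and the total row count, so boundary indices naturally produce zero-sized splits. For instance, for the [0, 4] case exercised at the end of the new tests:

# indices [0, 4] over 4 total rows (the last test_private_split_at_indices
# case): the first and last splits are zero-sized but still represented.
valid_indices = [0, 4]
helper = [0] + valid_indices + [4]  # [0, 0, 4, 4]
split_sizes = [helper[i] - helper[i - 1] for i in range(1, len(helper))]
assert split_sizes == [0, 4, 0]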
The second file in the diff adds unit tests for the new helpers:
@@ -6,14 +6,19 @@

 import numpy as np
 import pytest
+from ray.data.block import BlockMetadata

 import ray
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.plan import ExecutionPlan
 from ray.data._internal.stats import DatasetStats
 from ray.data._internal.split import (
+    _drop_empty_block_split,
     _generate_valid_indices,
     _generate_per_block_split_indices,
+    _generate_global_split_results,
+    _split_single_block,
+    _split_at_indices,
 )
 from ray.data.block import BlockAccessor
 from ray.data.dataset import Dataset
@@ -486,3 +491,124 @@ def test_generate_per_block_split_indices():
         [3, 3, 3, 1], [3, 10, 10]
     )
     assert [[], [], [], []] == _generate_per_block_split_indices([3, 3, 3, 1], [])
+
+
+def _create_meta(num_rows):
+    return BlockMetadata(
+        num_rows=num_rows,
+        size_bytes=None,
+        schema=None,
+        input_files=None,
+        exec_stats=None,
+    )
+
+
+def _create_block(data):
+    return (ray.put(data), _create_meta(len(data)))
+
+
+def test_split_single_block(ray_start_regular_shared):
+    block = [1, 2, 3]
+    meta = _create_meta(3)
+
+    block_id, splits = ray.get(
+        ray.remote(_split_single_block).remote(234, block, meta, [])
+    )
+    assert 234 == block_id
+    assert len(splits) == 1
+    assert ray.get(splits[0][0]) == [1, 2, 3]
+    assert splits[0][1].num_rows == 3
+
+    block_id, splits = ray.get(
+        ray.remote(_split_single_block).remote(234, block, meta, [1])
+    )
+    assert 234 == block_id
+    assert len(splits) == 2
+    assert ray.get(splits[0][0]) == [1]
+    assert splits[0][1].num_rows == 1
+    assert ray.get(splits[1][0]) == [2, 3]
+    assert splits[1][1].num_rows == 2
+
+    block_id, splits = ray.get(
+        ray.remote(_split_single_block).remote(234, block, meta, [0, 1, 1, 3])
+    )
+    assert 234 == block_id
+    assert len(splits) == 5
+    assert ray.get(splits[0][0]) == []
+    assert ray.get(splits[1][0]) == [1]
+    assert ray.get(splits[2][0]) == []
+    assert ray.get(splits[3][0]) == [2, 3]
+    assert ray.get(splits[4][0]) == []
+
+    block = []
+    meta = _create_meta(0)
+
+    block_id, splits = ray.get(
+        ray.remote(_split_single_block).remote(234, block, meta, [0])
+    )
+    assert 234 == block_id
+    assert len(splits) == 2
+    assert ray.get(splits[0][0]) == []
+    assert ray.get(splits[1][0]) == []
+
+
+def test_drop_empty_block_split():
+    assert [1, 2] == _drop_empty_block_split([0, 1, 2, 3], 3)
+    assert [1, 2] == _drop_empty_block_split([1, 1, 2, 2], 3)
+    assert [] == _drop_empty_block_split([0], 0)
+
+
+def verify_splits(splits, blocks_by_split):
+    assert len(splits) == len(blocks_by_split)
+    for blocks, (block_refs, meta) in zip(blocks_by_split, splits):
+        assert len(blocks) == len(block_refs)
+        assert len(blocks) == len(meta)
+        for block, block_ref, meta in zip(blocks, block_refs, meta):
+            assert ray.get(block_ref) == block
+            assert meta.num_rows == len(block)
+
+
+def test_generate_global_split_results(ray_start_regular_shared):
+    inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])]
+
+    splits = list(zip(*_generate_global_split_results(iter(inputs), [1, 2, 1])))
+    verify_splits(splits, [[[1]], [[2, 3]], [[4]]])
+
+    splits = list(zip(*_generate_global_split_results(iter(inputs), [3, 1])))
+    verify_splits(splits, [[[1], [2, 3]], [[4]]])
+
+    splits = list(zip(*_generate_global_split_results(iter(inputs), [3, 0, 1])))
+    verify_splits(splits, [[[1], [2, 3]], [], [[4]]])
+
+    inputs = []
+    splits = list(zip(*_generate_global_split_results(iter(inputs), [0, 0])))
+    verify_splits(splits, [[], []])
+
+
+def test_private_split_at_indices(ray_start_regular_shared):

(Review comment: note that we already have very comprehensive tests in test_split_at_indices; this one only covers edge cases and block distributions.)

+    inputs = []
+    splits = list(zip(*_split_at_indices(iter(inputs), [0])))
+    verify_splits(splits, [[], []])
+
+    splits = list(zip(*_split_at_indices(iter(inputs), [])))
+    verify_splits(splits, [[]])
+
+    inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])]
+
+    splits = list(zip(*_split_at_indices(iter(inputs), [1])))
+    verify_splits(splits, [[[1]], [[2, 3], [4]]])
+
+    splits = list(zip(*_split_at_indices(iter(inputs), [2])))
+    verify_splits(splits, [[[1], [2]], [[3], [4]]])
+
+    splits = list(zip(*_split_at_indices(iter(inputs), [1])))
+    verify_splits(splits, [[[1]], [[2, 3], [4]]])
+
+    splits = list(zip(*_split_at_indices(iter(inputs), [2, 2])))
+    verify_splits(splits, [[[1], [2]], [], [[3], [4]]])
+
+    splits = list(zip(*_split_at_indices(iter(inputs), [])))
+    verify_splits(splits, [[[1], [2, 3], [4]]])
+
+    splits = list(zip(*_split_at_indices(iter(inputs), [0, 4])))
+    verify_splits(splits, [[], [[1], [2, 3], [4]], []])
Review discussion:

Comment: This seems to have implications for the public API Dataset.split_at_indices? It can currently generate empty blocks, but this change would stop that.

Comment: How about dropping only the indices at the start/end of a block?

Comment: I think I'm fine with completely dropping empty splits (unless they have a utility I'm not aware of). In that case, we can just go to Dataset.split_at_indices and revamp the semantics there (we only need to de-dup the global indices and document that the API does so).

Comment: Alternative: we do this internally as in this PR, but when returning from Dataset.split_at_indices we create empty splits to maintain the existing semantics?

Comment: @jianoaix ah, this is an internal change and doesn't touch the split_at_indices public API. Previously a split dataset could contain empty blocks (blocks with 0 rows); now we remove those blocks, but the number of (split) datasets stays the same.

Comment: Yeah, this PR does not change split_at_indices semantics. Wondering whether it would be better to just embed this function's logic inside _generate_per_block_split_indices() (L70-72)? We could avoid adding these unuseful indices in the first place.
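A minimal sketch of that alternative, under a stated assumption: the body of _generate_per_block_split_indices is not part of this diff, so the helper below is hypothetical and only illustrates filtering while building rather than in a separate _drop_empty_block_split pass.

from typing import List

# Hypothetical sketch only: skip empty-producing indices while building each
# per-block index list, so no later filtering pass is needed.
def _append_non_empty_index(block_indices: List[int], index: int, num_rows: int) -> None:
    if index == 0 or index == num_rows:
        return  # boundary index would slice off an empty block
    if block_indices and block_indices[-1] == index:
        return  # duplicate index would slice off an empty block
    block_indices.append(index)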