Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data][Split optimization] don't generate empty blocks #26768

Merged
merged 4 commits into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 59 additions & 39 deletions python/ray/data/_internal/split.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import itertools
import logging
from typing import Iterable, Tuple, List

import ray
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import (
Block,
BlockPartition,
BlockAccessor,
BlockExecStats,
BlockMetadata,
Expand All @@ -15,7 +17,7 @@


def _calculate_blocks_rows(
blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
blocks_with_metadata: BlockPartition,
) -> List[int]:
"""Calculate the number of rows for a list of blocks with metadata."""
get_num_rows = cached_remote_fn(_get_num_rows)
Expand All @@ -24,6 +26,7 @@ def _calculate_blocks_rows(
if metadata.num_rows is None:
# Need to fetch number of rows.
num_rows = ray.get(get_num_rows.remote(block))
metadata.num_rows = num_rows
else:
num_rows = metadata.num_rows
block_rows.append(num_rows)
Expand Down Expand Up @@ -88,16 +91,15 @@ def _split_single_block(
block_id: int,
block: Block,
meta: BlockMetadata,
block_row: int,
split_indices: List[int],
) -> Tuple[int, List[Tuple[ObjectRef[Block], BlockMetadata]]]:
) -> Tuple[int, BlockPartition]:
"""Split the provided block at the given indices."""
split_result = []
block_accessor = BlockAccessor.for_block(block)
prev_index = 0
# append one more entry at the last so we don't
# need handle empty edge case.
split_indices.append(block_row)
split_indices.append(meta.num_rows)
for index in split_indices:
logger.debug(f"slicing block {prev_index}:{index}")
stats = BlockExecStats.builder()
Expand All @@ -115,23 +117,38 @@ def _split_single_block(
return (block_id, split_result)


def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to have implication on public API Dataset.split_at_indices? It currently can generate empty blocks but this will change it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about dropping only the ones for start/end of block?

Copy link
Contributor

@jianoaix jianoaix Jul 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm fine with completely dropping empty splits (unless they have a utility that I'm not aware of). In that case, we can just go to Dataset.split_at_indices and revamp the semantics there (we just need to de-dup the global indices and document the API that it will do so).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternative: we do this internally as this PR, but when returning Dataset.split_at_indices we create empty splits to maintain the existing semantics?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jianoaix ah this is internal change but not change the split_at_indices public api. So what happens is previously we will have dataset who contains empty block (block with 0 rows), now we remove the block but the number of (split) dataset is still the same.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this PR does not change split_at_indices semantics.

Wondering would it be better just embed this function's logic inside _generate_per_block_split_indices() L70-72 ? We can just avoid adding these unuseful indices in the first place.

    """Drop split indices that would create empty block splits.

    An empty split is produced by a duplicated index, or by an index equal
    to 0 (the start of the block) or to ``num_rows`` (the end of the block).

    Args:
        block_split_indices: Row indices at which the block is to be split.
        num_rows: Total number of rows in the block.

    Returns:
        The split indices with boundary and duplicate entries removed.
    """
    prev_index = -1
    optimized_indices = []
    for index in block_split_indices:
        # Splitting at the block's first or last row yields a zero-row slice.
        if index == 0 or index == num_rows:
            continue
        # De-dup by comparing against the previous index only; this assumes
        # the indices arrive sorted so duplicates are adjacent — TODO confirm
        # the caller always passes sorted indices.
        if index == prev_index:
            continue
        optimized_indices.append(index)
        prev_index = index
    return optimized_indices


def _split_all_blocks(
blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
block_rows: List[int],
blocks_with_metadata: BlockPartition,
per_block_split_indices: List[List[int]],
) -> List[Tuple[ObjectRef[Block], BlockMetadata]]:
) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
"""Split all the input blocks based on the split indices"""
split_single_block = cached_remote_fn(_split_single_block)

all_blocks_split_results: List[List[Tuple[ObjectRef[Block], BlockMetadata]]] = [
None
] * len(blocks_with_metadata)
all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)

split_single_block_futures = []

for block_id, block_split_indices in enumerate(per_block_split_indices):
(block_ref, meta) = blocks_with_metadata[block_id]
block_row = block_rows[block_id]
block_row = meta.num_rows
block_split_indices = _drop_empty_block_split(block_split_indices, block_row)
if len(block_split_indices) == 0:
# optimization: if no split is needed, we just need to add it to the
# result
Expand All @@ -143,46 +160,44 @@ def _split_all_blocks(
block_id,
block_ref,
meta,
block_row,
block_split_indices,
)
)
if split_single_block_futures:
split_single_block_results = ray.get(split_single_block_futures)
for block_id, block_split_result in split_single_block_results:
all_blocks_split_results[block_id] = block_split_result
return all_blocks_split_results
return itertools.chain.from_iterable(all_blocks_split_results)


def _generate_global_split_results(
all_blocks_split_results: List[List[Tuple[ObjectRef[Block], BlockMetadata]]],
all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]],
global_split_sizes: List[int],
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
"""Reassemble per block's split result into final split result."""
result_blocks = []
result_metas = []

current_blocks = []
current_meta = []

if len(all_blocks_split_results) == 0:
return ([], [])

for single_block_split_result in all_blocks_split_results:
assert len(single_block_split_result) > 0
for i, (block, meta) in enumerate(single_block_split_result):
# we should create a new global split whenever
# we encountered a new local split in the per block
# split result.
if i != 0:
result_blocks.append(current_blocks)
result_metas.append(current_meta)
current_blocks = []
current_meta = []
current_blocks.append(block)
current_split_size = 0
current_split_id = 0

while current_split_id < len(global_split_sizes):
if current_split_size >= global_split_sizes[current_split_id]:
assert current_split_size == global_split_sizes[current_split_id]
result_blocks.append(current_blocks)
result_metas.append(current_meta)

current_blocks = []
current_meta = []
current_split_size = 0
current_split_id += 1
else:
(block_ref, meta) = next(all_blocks_split_results)
current_blocks.append(block_ref)
current_meta.append(meta)

assert len(current_blocks) > 0
result_blocks.append(current_blocks)
result_metas.append(current_meta)
current_split_size += meta.num_rows

return result_blocks, result_metas

Expand All @@ -205,20 +220,25 @@ def _split_at_indices(
# phase 1: calculate the per block split indices.
blocks_with_metadata = list(blocks_with_metadata)
if len(blocks_with_metadata) == 0:
return ([] * (len(indices) + 1), [] * (len(indices) + 1))
return ([[]] * (len(indices) + 1), [[]] * (len(indices) + 1))
block_rows: List[int] = _calculate_blocks_rows(blocks_with_metadata)
valid_indices = _generate_valid_indices(block_rows, indices)
per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
block_rows, valid_indices
)

# phase 2: split each block based on the indices from previous step.
all_blocks_split_results: List[
List[Tuple[ObjectRef[Block], BlockMetadata]]
] = _split_all_blocks(blocks_with_metadata, block_rows, per_block_split_indices)
all_blocks_split_results: Iterable[
Tuple[ObjectRef[Block], BlockMetadata]
] = _split_all_blocks(blocks_with_metadata, per_block_split_indices)

# phase 3: generate the final split.
return _generate_global_split_results(all_blocks_split_results)

# first calculate the size for each split.
helper = [0] + valid_indices + [sum(block_rows)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

valid_indices are the user-provided indices without deduplication. The final blocks are generated based on valid_indices, so we will still generate empty splits, as before, if the user provides duplicated indices, FYI @jianoaix and @matthewdeng.

split_sizes = [helper[i] - helper[i - 1] for i in range(1, len(helper))]

return _generate_global_split_results(all_blocks_split_results, split_sizes)


def _get_num_rows(block: Block) -> int:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,7 @@ def test_repartition_noshuffle(ray_start_regular_shared):
blocks = ray.get(ds4.get_internal_block_refs())
assert all(isinstance(block, list) for block in blocks), blocks
assert ds4.sum() == 190
assert ds4._block_num_rows() == [0, 1] * 20
assert ds4._block_num_rows() == [1] * 20 + [0] * 20

ds5 = ray.data.range(22).repartition(4)
assert ds5.num_blocks() == 4
Expand Down
126 changes: 126 additions & 0 deletions python/ray/data/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@

import numpy as np
import pytest
from ray.data.block import BlockMetadata

import ray
from ray.data._internal.block_list import BlockList
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.stats import DatasetStats
from ray.data._internal.split import (
_drop_empty_block_split,
_generate_valid_indices,
_generate_per_block_split_indices,
_generate_global_split_results,
_split_single_block,
_split_at_indices,
)
from ray.data.block import BlockAccessor
from ray.data.dataset import Dataset
Expand Down Expand Up @@ -486,3 +491,124 @@ def test_generate_per_block_split_indices():
[3, 3, 3, 1], [3, 10, 10]
)
assert [[], [], [], []] == _generate_per_block_split_indices([3, 3, 3, 1], [])


def _create_meta(num_rows):
    """Build a minimal BlockMetadata that carries only the given row count."""
    # Everything except the row count is irrelevant to the split logic
    # under test, so leave the remaining fields unset.
    empty_fields = {
        field: None
        for field in ("size_bytes", "schema", "input_files", "exec_stats")
    }
    return BlockMetadata(num_rows=num_rows, **empty_fields)


def _create_block(data):
    """Put *data* into the object store and pair it with matching metadata."""
    block_ref = ray.put(data)
    meta = _create_meta(len(data))
    return block_ref, meta


def test_split_single_block(ray_start_regular_shared):
    """_split_single_block should slice a block at the given row indices."""
    split_remote = ray.remote(_split_single_block)
    data = [1, 2, 3]
    meta = _create_meta(3)

    # No indices: the whole block comes back as a single split.
    block_id, splits = ray.get(split_remote.remote(234, data, meta, []))
    assert block_id == 234
    assert [ray.get(ref) for ref, _ in splits] == [[1, 2, 3]]
    assert [m.num_rows for _, m in splits] == [3]

    # A single interior index produces two splits.
    block_id, splits = ray.get(split_remote.remote(234, data, meta, [1]))
    assert block_id == 234
    assert [ray.get(ref) for ref, _ in splits] == [[1], [2, 3]]
    assert [m.num_rows for _, m in splits] == [1, 2]

    # Boundary and duplicated indices produce empty splits.
    block_id, splits = ray.get(split_remote.remote(234, data, meta, [0, 1, 1, 3]))
    assert block_id == 234
    assert [ray.get(ref) for ref, _ in splits] == [[], [1], [], [2, 3], []]

    # Splitting an empty block at index 0 yields two empty splits.
    block_id, splits = ray.get(
        split_remote.remote(234, [], _create_meta(0), [0])
    )
    assert block_id == 234
    assert [ray.get(ref) for ref, _ in splits] == [[], []]


def test_drop_empty_block_split():
    """Boundary (0 / num_rows) and duplicated indices must be dropped."""
    cases = [
        ([0, 1, 2, 3], 3, [1, 2]),  # boundary indices removed
        ([1, 1, 2, 2], 3, [1, 2]),  # duplicates removed
        ([0], 0, []),               # empty block: nothing survives
    ]
    for indices, num_rows, expected in cases:
        assert _drop_empty_block_split(indices, num_rows) == expected


def verify_splits(splits, blocks_by_split):
    """Check that each split contains the expected blocks and metadata.

    Args:
        splits: List of ``(block_refs, metadata_list)`` pairs, one per split.
        blocks_by_split: Expected contents, as one list of row-lists per split.
    """
    assert len(splits) == len(blocks_by_split)
    for expected_blocks, (block_refs, metas) in zip(blocks_by_split, splits):
        assert len(expected_blocks) == len(block_refs)
        assert len(expected_blocks) == len(metas)
        # NOTE: the original inner loop reused the name ``meta`` for both the
        # metadata list and the per-block metadata, shadowing the outer
        # variable inside the same zip() call; renamed for clarity.
        for expected_rows, block_ref, block_meta in zip(
            expected_blocks, block_refs, metas
        ):
            assert ray.get(block_ref) == expected_rows
            # Metadata row count must agree with the actual block contents.
            assert block_meta.num_rows == len(expected_rows)


def test_generate_global_split_results(ray_start_regular_shared):
    """Blocks should be regrouped into splits matching the requested sizes."""
    blocks = [_create_block([1]), _create_block([2, 3]), _create_block([4])]

    def run(split_sizes):
        # _generate_global_split_results returns (blocks, metas) column-wise;
        # zip(*) transposes it into one (block_refs, metas) pair per split.
        return list(zip(*_generate_global_split_results(iter(blocks), split_sizes)))

    verify_splits(run([1, 2, 1]), [[[1]], [[2, 3]], [[4]]])
    verify_splits(run([3, 1]), [[[1], [2, 3]], [[4]]])
    # A zero-sized split yields an empty group.
    verify_splits(run([3, 0, 1]), [[[1], [2, 3]], [], [[4]]])

    # No input blocks: all splits are empty.
    blocks = []
    verify_splits(run([0, 0]), [[], []])


def test_private_split_at_indices(ray_start_regular_shared):
    """Edge cases for _split_at_indices: empty input, boundary and
    duplicated indices, and splits that fall inside or between blocks.
    (The public Dataset.split_at_indices has its own comprehensive tests.)
    """
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note we have very comprehensive test for test_split_at_indices. this is only for edge cases and block distributions.

    # Empty input: every split is empty, and len(indices) + 1 splits come back.
    inputs = []
    splits = list(zip(*_split_at_indices(iter(inputs), [0])))
    verify_splits(splits, [[], []])

    splits = list(zip(*_split_at_indices(iter(inputs), [])))
    verify_splits(splits, [[]])

    inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])]

    # Split exactly on a block boundary: no block needs to be cut.
    splits = list(zip(*_split_at_indices(iter(inputs), [1])))
    verify_splits(splits, [[[1]], [[2, 3], [4]]])

    # Split inside the middle block: it is cut into [2] and [3].
    splits = list(zip(*_split_at_indices(iter(inputs), [2])))
    verify_splits(splits, [[[1], [2]], [[3], [4]]])

    splits = list(zip(*_split_at_indices(iter(inputs), [1])))
    verify_splits(splits, [[[1]], [[2, 3], [4]]])

    # Duplicated index: the middle split is empty.
    splits = list(zip(*_split_at_indices(iter(inputs), [2, 2])))
    verify_splits(splits, [[[1], [2]], [], [[3], [4]]])

    # No indices: everything lands in one split.
    splits = list(zip(*_split_at_indices(iter(inputs), [])))
    verify_splits(splits, [[[1], [2, 3], [4]]])

    # Indices at the global start/end produce empty leading/trailing splits.
    splits = list(zip(*_split_at_indices(iter(inputs), [0, 4])))
    verify_splits(splits, [[], [[1], [2, 3], [4]], []])