diff --git a/python/ray/data/_internal/split.py b/python/ray/data/_internal/split.py index 9536ec26d2e3..60ecc5100d61 100644 --- a/python/ray/data/_internal/split.py +++ b/python/ray/data/_internal/split.py @@ -1,3 +1,4 @@ +import itertools import logging from typing import Iterable, Tuple, List @@ -5,6 +6,7 @@ from ray.data._internal.remote_fn import cached_remote_fn from ray.data.block import ( Block, + BlockPartition, BlockAccessor, BlockExecStats, BlockMetadata, @@ -15,7 +17,7 @@ def _calculate_blocks_rows( - blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]], + blocks_with_metadata: BlockPartition, ) -> List[int]: """Calculate the number of rows for a list of blocks with metadata.""" get_num_rows = cached_remote_fn(_get_num_rows) @@ -24,6 +26,7 @@ def _calculate_blocks_rows( if metadata.num_rows is None: # Need to fetch number of rows. num_rows = ray.get(get_num_rows.remote(block)) + metadata.num_rows = num_rows else: num_rows = metadata.num_rows block_rows.append(num_rows) @@ -88,16 +91,15 @@ def _split_single_block( block_id: int, block: Block, meta: BlockMetadata, - block_row: int, split_indices: List[int], -) -> Tuple[int, List[Tuple[ObjectRef[Block], BlockMetadata]]]: +) -> Tuple[int, BlockPartition]: """Split the provided block at the given indices.""" split_result = [] block_accessor = BlockAccessor.for_block(block) prev_index = 0 # append one more entry at the last so we don't # need handle empty edge case. - split_indices.append(block_row) + split_indices.append(meta.num_rows) for index in split_indices: logger.debug(f"slicing block {prev_index}:{index}") stats = BlockExecStats.builder() @@ -115,23 +117,38 @@ def _split_single_block( return (block_id, split_result) +def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]: + """drop split indices that creates empty block split. This could happen when there + are duplicated indices, or index equal to 0 (start of the block) or num_block_rows + (end of the block). + """ + prev_index = -1 + optimized_indices = [] + for index in block_split_indices: + if index == 0 or index == num_rows: + continue + if index == prev_index: + continue + optimized_indices.append(index) + prev_index = index + return optimized_indices + + def _split_all_blocks( - blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]], - block_rows: List[int], + blocks_with_metadata: BlockPartition, per_block_split_indices: List[List[int]], -) -> List[Tuple[ObjectRef[Block], BlockMetadata]]: +) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]: """Split all the input blocks based on the split indices""" split_single_block = cached_remote_fn(_split_single_block) - all_blocks_split_results: List[List[Tuple[ObjectRef[Block], BlockMetadata]]] = [ - None - ] * len(blocks_with_metadata) + all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata) split_single_block_futures = [] for block_id, block_split_indices in enumerate(per_block_split_indices): (block_ref, meta) = blocks_with_metadata[block_id] - block_row = block_rows[block_id] + block_row = meta.num_rows + block_split_indices = _drop_empty_block_split(block_split_indices, block_row) if len(block_split_indices) == 0: # optimization: if no split is needed, we just need to add it to the # result @@ -143,7 +160,6 @@ def _split_all_blocks( block_id, block_ref, meta, - block_row, block_split_indices, ) ) @@ -151,38 +167,37 @@ def _split_all_blocks( split_single_block_results = ray.get(split_single_block_futures) for block_id, block_split_result in split_single_block_results: all_blocks_split_results[block_id] = block_split_result - return all_blocks_split_results + return itertools.chain.from_iterable(all_blocks_split_results) def _generate_global_split_results( - all_blocks_split_results: List[List[Tuple[ObjectRef[Block], BlockMetadata]]], + all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]], + global_split_sizes: List[int], ) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]: """Reassemble per block's split result into final split result.""" result_blocks = [] result_metas = [] + current_blocks = [] current_meta = [] - - if len(all_blocks_split_results) == 0: - return ([], []) - - for single_block_split_result in all_blocks_split_results: - assert len(single_block_split_result) > 0 - for i, (block, meta) in enumerate(single_block_split_result): - # we should create a new global split whenever - # we encountered a new local split in the per block - # split result. - if i != 0: - result_blocks.append(current_blocks) - result_metas.append(current_meta) - current_blocks = [] - current_meta = [] - current_blocks.append(block) + current_split_size = 0 + current_split_id = 0 + + while current_split_id < len(global_split_sizes): + if current_split_size >= global_split_sizes[current_split_id]: + assert current_split_size == global_split_sizes[current_split_id] + result_blocks.append(current_blocks) + result_metas.append(current_meta) + + current_blocks = [] + current_meta = [] + current_split_size = 0 + current_split_id += 1 + else: + (block_ref, meta) = next(all_blocks_split_results) + current_blocks.append(block_ref) current_meta.append(meta) - - assert len(current_blocks) > 0 - result_blocks.append(current_blocks) - result_metas.append(current_meta) + current_split_size += meta.num_rows return result_blocks, result_metas @@ -205,7 +220,7 @@ def _split_at_indices( # phase 1: calculate the per block split indices. blocks_with_metadata = list(blocks_with_metadata) if len(blocks_with_metadata) == 0: - return ([] * (len(indices) + 1), [] * (len(indices) + 1)) + return ([[]] * (len(indices) + 1), [[]] * (len(indices) + 1)) block_rows: List[int] = _calculate_blocks_rows(blocks_with_metadata) valid_indices = _generate_valid_indices(block_rows, indices) per_block_split_indices: List[List[int]] = _generate_per_block_split_indices( @@ -213,12 +228,17 @@ def _split_at_indices( ) # phase 2: split each block based on the indices from previous step. - all_blocks_split_results: List[ - List[Tuple[ObjectRef[Block], BlockMetadata]] - ] = _split_all_blocks(blocks_with_metadata, block_rows, per_block_split_indices) + all_blocks_split_results: Iterable[ + Tuple[ObjectRef[Block], BlockMetadata] + ] = _split_all_blocks(blocks_with_metadata, per_block_split_indices) # phase 3: generate the final split. - return _generate_global_split_results(all_blocks_split_results) + + # first calculate the size for each split. + helper = [0] + valid_indices + [sum(block_rows)] + split_sizes = [helper[i] - helper[i - 1] for i in range(1, len(helper))] + + return _generate_global_split_results(all_blocks_split_results, split_sizes) def _get_num_rows(block: Block) -> int: diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 74bafd715757..a3fa013112a2 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -1392,7 +1392,7 @@ def test_repartition_noshuffle(ray_start_regular_shared): blocks = ray.get(ds4.get_internal_block_refs()) assert all(isinstance(block, list) for block in blocks), blocks assert ds4.sum() == 190 - assert ds4._block_num_rows() == [0, 1] * 20 + assert ds4._block_num_rows() == [1] * 20 + [0] * 20 ds5 = ray.data.range(22).repartition(4) assert ds5.num_blocks() == 4 diff --git a/python/ray/data/tests/test_split.py b/python/ray/data/tests/test_split.py index fb15c6303fc1..e54a3f018d36 100644 --- a/python/ray/data/tests/test_split.py +++ b/python/ray/data/tests/test_split.py @@ -6,14 +6,19 @@ import numpy as np import pytest +from ray.data.block import BlockMetadata import ray from ray.data._internal.block_list import BlockList from ray.data._internal.plan import ExecutionPlan from ray.data._internal.stats import DatasetStats from ray.data._internal.split import ( + _drop_empty_block_split, _generate_valid_indices, _generate_per_block_split_indices, + _generate_global_split_results, + _split_single_block, + _split_at_indices, ) from ray.data.block import BlockAccessor from ray.data.dataset import Dataset @@ -486,3 +491,124 @@ def test_generate_per_block_split_indices(): [3, 3, 3, 1], [3, 10, 10] ) assert [[], [], [], []] == _generate_per_block_split_indices([3, 3, 3, 1], []) + + +def _create_meta(num_rows): + return BlockMetadata( + num_rows=num_rows, + size_bytes=None, + schema=None, + input_files=None, + exec_stats=None, + ) + + +def _create_block(data): + return (ray.put(data), _create_meta(len(data))) + + +def test_split_single_block(ray_start_regular_shared): + block = [1, 2, 3] + meta = _create_meta(3) + + block_id, splits = ray.get( + ray.remote(_split_single_block).remote(234, block, meta, []) + ) + assert 234 == block_id + assert len(splits) == 1 + assert ray.get(splits[0][0]) == [1, 2, 3] + assert splits[0][1].num_rows == 3 + + block_id, splits = ray.get( + ray.remote(_split_single_block).remote(234, block, meta, [1]) + ) + assert 234 == block_id + assert len(splits) == 2 + assert ray.get(splits[0][0]) == [1] + assert splits[0][1].num_rows == 1 + assert ray.get(splits[1][0]) == [2, 3] + assert splits[1][1].num_rows == 2 + + block_id, splits = ray.get( + ray.remote(_split_single_block).remote(234, block, meta, [0, 1, 1, 3]) + ) + assert 234 == block_id + assert len(splits) == 5 + assert ray.get(splits[0][0]) == [] + assert ray.get(splits[1][0]) == [1] + assert ray.get(splits[2][0]) == [] + assert ray.get(splits[3][0]) == [2, 3] + assert ray.get(splits[4][0]) == [] + + block = [] + meta = _create_meta(0) + + block_id, splits = ray.get( + ray.remote(_split_single_block).remote(234, block, meta, [0]) + ) + assert 234 == block_id + assert len(splits) == 2 + assert ray.get(splits[0][0]) == [] + assert ray.get(splits[1][0]) == [] + + +def test_drop_empty_block_split(): + assert [1, 2] == _drop_empty_block_split([0, 1, 2, 3], 3) + assert [1, 2] == _drop_empty_block_split([1, 1, 2, 2], 3) + assert [] == _drop_empty_block_split([0], 0) + + +def verify_splits(splits, blocks_by_split): + assert len(splits) == len(blocks_by_split) + for blocks, (block_refs, meta) in zip(blocks_by_split, splits): + assert len(blocks) == len(block_refs) + assert len(blocks) == len(meta) + for block, block_ref, meta in zip(blocks, block_refs, meta): + assert ray.get(block_ref) == block + assert meta.num_rows == len(block) + + +def test_generate_global_split_results(ray_start_regular_shared): + inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])] + + splits = list(zip(*_generate_global_split_results(iter(inputs), [1, 2, 1]))) + verify_splits(splits, [[[1]], [[2, 3]], [[4]]]) + + splits = list(zip(*_generate_global_split_results(iter(inputs), [3, 1]))) + verify_splits(splits, [[[1], [2, 3]], [[4]]]) + + splits = list(zip(*_generate_global_split_results(iter(inputs), [3, 0, 1]))) + verify_splits(splits, [[[1], [2, 3]], [], [[4]]]) + + inputs = [] + splits = list(zip(*_generate_global_split_results(iter(inputs), [0, 0]))) + verify_splits(splits, [[], []]) + + +def test_private_split_at_indices(ray_start_regular_shared): + inputs = [] + splits = list(zip(*_split_at_indices(iter(inputs), [0]))) + verify_splits(splits, [[], []]) + + splits = list(zip(*_split_at_indices(iter(inputs), []))) + verify_splits(splits, [[]]) + + inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])] + + splits = list(zip(*_split_at_indices(iter(inputs), [1]))) + verify_splits(splits, [[[1]], [[2, 3], [4]]]) + + splits = list(zip(*_split_at_indices(iter(inputs), [2]))) + verify_splits(splits, [[[1], [2]], [[3], [4]]]) + + splits = list(zip(*_split_at_indices(iter(inputs), [1]))) + verify_splits(splits, [[[1]], [[2, 3], [4]]]) + + splits = list(zip(*_split_at_indices(iter(inputs), [2, 2]))) + verify_splits(splits, [[[1], [2]], [], [[3], [4]]]) + + splits = list(zip(*_split_at_indices(iter(inputs), []))) + verify_splits(splits, [[[1], [2, 3], [4]]]) + + splits = list(zip(*_split_at_indices(iter(inputs), [0, 4]))) + verify_splits(splits, [[], [[1], [2, 3], [4]], []])