Skip to content

Commit

Permalink
[Data][Split optimization] don't generate empty blocks (ray-project#2…
Browse files Browse the repository at this point in the history
…6768)

The current split_at_index might generate empty blocks and also trigger unnecessary split tasks. The empty blocks happen when there are duplicate split indices, or when a split index falls at a block boundary. The unnecessary split tasks are triggered when the split index falls at a block boundary.

This PR fixes that by checking whether a split index is duplicated or falls at a block boundary. In those cases, we can safely ignore the index.

Signed-off-by: Stefan van der Kleij <[email protected]>
  • Loading branch information
scv119 authored and Stefan van der Kleij committed Aug 18, 2022
1 parent e8e0c4e commit 02e98ab
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 40 deletions.
98 changes: 59 additions & 39 deletions python/ray/data/_internal/split.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import itertools
import logging
from typing import Iterable, Tuple, List

import ray
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import (
Block,
BlockPartition,
BlockAccessor,
BlockExecStats,
BlockMetadata,
Expand All @@ -15,7 +17,7 @@


def _calculate_blocks_rows(
blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
blocks_with_metadata: BlockPartition,
) -> List[int]:
"""Calculate the number of rows for a list of blocks with metadata."""
get_num_rows = cached_remote_fn(_get_num_rows)
Expand All @@ -24,6 +26,7 @@ def _calculate_blocks_rows(
if metadata.num_rows is None:
# Need to fetch number of rows.
num_rows = ray.get(get_num_rows.remote(block))
metadata.num_rows = num_rows
else:
num_rows = metadata.num_rows
block_rows.append(num_rows)
Expand Down Expand Up @@ -88,16 +91,15 @@ def _split_single_block(
block_id: int,
block: Block,
meta: BlockMetadata,
block_row: int,
split_indices: List[int],
) -> Tuple[int, List[Tuple[ObjectRef[Block], BlockMetadata]]]:
) -> Tuple[int, BlockPartition]:
"""Split the provided block at the given indices."""
split_result = []
block_accessor = BlockAccessor.for_block(block)
prev_index = 0
# append one more entry at the last so we don't
# need handle empty edge case.
split_indices.append(block_row)
split_indices.append(meta.num_rows)
for index in split_indices:
logger.debug(f"slicing block {prev_index}:{index}")
stats = BlockExecStats.builder()
Expand All @@ -115,23 +117,38 @@ def _split_single_block(
return (block_id, split_result)


def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
"""drop split indices that creates empty block split. This could happen when there
are duplicated indices, or index equal to 0 (start of the block) or num_block_rows
(end of the block).
"""
prev_index = -1
optimized_indices = []
for index in block_split_indices:
if index == 0 or index == num_rows:
continue
if index == prev_index:
continue
optimized_indices.append(index)
prev_index = index
return optimized_indices


def _split_all_blocks(
blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
block_rows: List[int],
blocks_with_metadata: BlockPartition,
per_block_split_indices: List[List[int]],
) -> List[Tuple[ObjectRef[Block], BlockMetadata]]:
) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
"""Split all the input blocks based on the split indices"""
split_single_block = cached_remote_fn(_split_single_block)

all_blocks_split_results: List[List[Tuple[ObjectRef[Block], BlockMetadata]]] = [
None
] * len(blocks_with_metadata)
all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)

split_single_block_futures = []

for block_id, block_split_indices in enumerate(per_block_split_indices):
(block_ref, meta) = blocks_with_metadata[block_id]
block_row = block_rows[block_id]
block_row = meta.num_rows
block_split_indices = _drop_empty_block_split(block_split_indices, block_row)
if len(block_split_indices) == 0:
# optimization: if no split is needed, we just need to add it to the
# result
Expand All @@ -143,46 +160,44 @@ def _split_all_blocks(
block_id,
block_ref,
meta,
block_row,
block_split_indices,
)
)
if split_single_block_futures:
split_single_block_results = ray.get(split_single_block_futures)
for block_id, block_split_result in split_single_block_results:
all_blocks_split_results[block_id] = block_split_result
return all_blocks_split_results
return itertools.chain.from_iterable(all_blocks_split_results)


def _generate_global_split_results(
all_blocks_split_results: List[List[Tuple[ObjectRef[Block], BlockMetadata]]],
all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]],
global_split_sizes: List[int],
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
"""Reassemble per block's split result into final split result."""
result_blocks = []
result_metas = []

current_blocks = []
current_meta = []

if len(all_blocks_split_results) == 0:
return ([], [])

for single_block_split_result in all_blocks_split_results:
assert len(single_block_split_result) > 0
for i, (block, meta) in enumerate(single_block_split_result):
# we should create a new global split whenever
# we encountered a new local split in the per block
# split result.
if i != 0:
result_blocks.append(current_blocks)
result_metas.append(current_meta)
current_blocks = []
current_meta = []
current_blocks.append(block)
current_split_size = 0
current_split_id = 0

while current_split_id < len(global_split_sizes):
if current_split_size >= global_split_sizes[current_split_id]:
assert current_split_size == global_split_sizes[current_split_id]
result_blocks.append(current_blocks)
result_metas.append(current_meta)

current_blocks = []
current_meta = []
current_split_size = 0
current_split_id += 1
else:
(block_ref, meta) = next(all_blocks_split_results)
current_blocks.append(block_ref)
current_meta.append(meta)

assert len(current_blocks) > 0
result_blocks.append(current_blocks)
result_metas.append(current_meta)
current_split_size += meta.num_rows

return result_blocks, result_metas

Expand All @@ -205,20 +220,25 @@ def _split_at_indices(
# phase 1: calculate the per block split indices.
blocks_with_metadata = list(blocks_with_metadata)
if len(blocks_with_metadata) == 0:
return ([] * (len(indices) + 1), [] * (len(indices) + 1))
return ([[]] * (len(indices) + 1), [[]] * (len(indices) + 1))
block_rows: List[int] = _calculate_blocks_rows(blocks_with_metadata)
valid_indices = _generate_valid_indices(block_rows, indices)
per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
block_rows, valid_indices
)

# phase 2: split each block based on the indices from previous step.
all_blocks_split_results: List[
List[Tuple[ObjectRef[Block], BlockMetadata]]
] = _split_all_blocks(blocks_with_metadata, block_rows, per_block_split_indices)
all_blocks_split_results: Iterable[
Tuple[ObjectRef[Block], BlockMetadata]
] = _split_all_blocks(blocks_with_metadata, per_block_split_indices)

# phase 3: generate the final split.
return _generate_global_split_results(all_blocks_split_results)

# first calculate the size for each split.
helper = [0] + valid_indices + [sum(block_rows)]
split_sizes = [helper[i] - helper[i - 1] for i in range(1, len(helper))]

return _generate_global_split_results(all_blocks_split_results, split_sizes)


def _get_num_rows(block: Block) -> int:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,7 @@ def test_repartition_noshuffle(ray_start_regular_shared):
blocks = ray.get(ds4.get_internal_block_refs())
assert all(isinstance(block, list) for block in blocks), blocks
assert ds4.sum() == 190
assert ds4._block_num_rows() == [0, 1] * 20
assert ds4._block_num_rows() == [1] * 20 + [0] * 20

ds5 = ray.data.range(22).repartition(4)
assert ds5.num_blocks() == 4
Expand Down
126 changes: 126 additions & 0 deletions python/ray/data/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@

import numpy as np
import pytest
from ray.data.block import BlockMetadata

import ray
from ray.data._internal.block_list import BlockList
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.stats import DatasetStats
from ray.data._internal.split import (
_drop_empty_block_split,
_generate_valid_indices,
_generate_per_block_split_indices,
_generate_global_split_results,
_split_single_block,
_split_at_indices,
)
from ray.data.block import BlockAccessor
from ray.data.dataset import Dataset
Expand Down Expand Up @@ -486,3 +491,124 @@ def test_generate_per_block_split_indices():
[3, 3, 3, 1], [3, 10, 10]
)
assert [[], [], [], []] == _generate_per_block_split_indices([3, 3, 3, 1], [])


def _create_meta(num_rows):
    """Build a minimal BlockMetadata carrying only a row count.

    All fields other than ``num_rows`` are left unset (None), which is
    sufficient for the split helpers exercised by these tests.
    """
    unset_fields = dict(
        size_bytes=None,
        schema=None,
        input_files=None,
        exec_stats=None,
    )
    return BlockMetadata(num_rows=num_rows, **unset_fields)


def _create_block(data):
    """Put *data* in the object store and pair the ref with matching metadata."""
    block_ref = ray.put(data)
    return (block_ref, _create_meta(len(data)))


def test_split_single_block(ray_start_regular_shared):
    """_split_single_block slices a block at the given indices."""

    def run_split(block, meta, indices):
        # Execute the split as a remote task, mirroring production usage.
        return ray.get(
            ray.remote(_split_single_block).remote(234, block, meta, indices)
        )

    data = [1, 2, 3]
    meta = _create_meta(3)

    # No indices: a single slice covering the whole block.
    block_id, splits = run_split(data, meta, [])
    assert block_id == 234
    assert len(splits) == 1
    assert ray.get(splits[0][0]) == [1, 2, 3]
    assert splits[0][1].num_rows == 3

    # One interior index: two non-empty slices.
    block_id, splits = run_split(data, meta, [1])
    assert block_id == 234
    assert len(splits) == 2
    assert ray.get(splits[0][0]) == [1]
    assert splits[0][1].num_rows == 1
    assert ray.get(splits[1][0]) == [2, 3]
    assert splits[1][1].num_rows == 2

    # Boundary and duplicate indices yield empty slices in between.
    block_id, splits = run_split(data, meta, [0, 1, 1, 3])
    assert block_id == 234
    assert len(splits) == 5
    assert ray.get(splits[0][0]) == []
    assert ray.get(splits[1][0]) == [1]
    assert ray.get(splits[2][0]) == []
    assert ray.get(splits[3][0]) == [2, 3]
    assert ray.get(splits[4][0]) == []

    # Splitting an empty block at index 0 gives two empty slices.
    block_id, splits = run_split([], _create_meta(0), [0])
    assert block_id == 234
    assert len(splits) == 2
    assert ray.get(splits[0][0]) == []
    assert ray.get(splits[1][0]) == []


def test_drop_empty_block_split():
    """Boundary and duplicate split indices are filtered out."""
    cases = [
        (([0, 1, 2, 3], 3), [1, 2]),  # 0 and num_rows are dropped
        (([1, 1, 2, 2], 3), [1, 2]),  # duplicates are collapsed
        (([0], 0), []),  # empty block keeps nothing
    ]
    for (indices, num_rows), expected in cases:
        assert _drop_empty_block_split(indices, num_rows) == expected


def verify_splits(splits, blocks_by_split):
    """Assert that each split's block refs and metadata match the expected rows.

    ``blocks_by_split`` is a list (one entry per split) of lists of expected
    block contents; ``splits`` pairs each split's block refs with its metadata.
    """
    assert len(splits) == len(blocks_by_split)
    for expected_blocks, (block_refs, metas) in zip(blocks_by_split, splits):
        assert len(block_refs) == len(expected_blocks)
        assert len(metas) == len(expected_blocks)
        for expected, ref, block_meta in zip(expected_blocks, block_refs, metas):
            assert ray.get(ref) == expected
            assert block_meta.num_rows == len(expected)


def test_generate_global_split_results(ray_start_regular_shared):
    """Per-block split results are regrouped into global splits by size."""

    def make_splits(blocks, split_sizes):
        return list(zip(*_generate_global_split_results(iter(blocks), split_sizes)))

    inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])]

    # Sizes line up exactly with block boundaries.
    verify_splits(make_splits(inputs, [1, 2, 1]), [[[1]], [[2, 3]], [[4]]])
    verify_splits(make_splits(inputs, [3, 1]), [[[1], [2, 3]], [[4]]])
    # A zero-sized split produces an empty group.
    verify_splits(make_splits(inputs, [3, 0, 1]), [[[1], [2, 3]], [], [[4]]])

    # No input blocks at all.
    verify_splits(make_splits([], [0, 0]), [[], []])


def test_private_split_at_indices(ray_start_regular_shared):
    """_split_at_indices yields len(indices) + 1 splits with the right rows."""

    def split(blocks, indices):
        return list(zip(*_split_at_indices(iter(blocks), indices)))

    # Empty input still produces the expected number of (empty) splits.
    verify_splits(split([], [0]), [[], []])
    verify_splits(split([], []), [[]])

    inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])]

    # Index on a block boundary: no block needs to be physically split.
    verify_splits(split(inputs, [1]), [[[1]], [[2, 3], [4]]])
    # Index inside a block forces that block to be split.
    verify_splits(split(inputs, [2]), [[[1], [2]], [[3], [4]]])
    verify_splits(split(inputs, [1]), [[[1]], [[2, 3], [4]]])
    # Duplicate index produces an empty middle split.
    verify_splits(split(inputs, [2, 2]), [[[1], [2]], [], [[3], [4]]])
    # No indices: everything lands in one split.
    verify_splits(split(inputs, []), [[[1], [2, 3], [4]]])
    # Indices at the global boundaries produce empty leading/trailing splits.
    verify_splits(split(inputs, [0, 4]), [[], [[1], [2, 3], [4]], []])

0 comments on commit 02e98ab

Please sign in to comment.