Skip to content

Commit

Permalink
[Data] Add override_num_blocks to from_pandas and perform auto-pa…
Browse files Browse the repository at this point in the history
…rtition (#44937)

A common pattern is to load a DataFrame containing file URIs with from_pandas and then loading those URIs with map_batches. If you have a single large DataFrame, the subsequent operator (e.g., for reading) won't be parallelized because from_pandas produces one input block.

To fix this issue, this PR automatically splits DataFrames into a good number of blocks, and allows the user to override the number of blocks.

Signed-off-by: Balaji Veeramani <[email protected]>
  • Loading branch information
bveeramani authored May 25, 2024
1 parent 91a6ed2 commit f13d144
Show file tree
Hide file tree
Showing 13 changed files with 159 additions and 23 deletions.
1 change: 1 addition & 0 deletions python/ray/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ray.data.read_api import ( # noqa: F401
from_arrow,
from_arrow_refs,
from_blocks,
from_dask,
from_huggingface,
from_items,
Expand Down
6 changes: 6 additions & 0 deletions python/ray/data/_internal/logical/operators/from_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class FromItems(AbstractFrom):
pass


class FromBlocks(AbstractFrom):
"""Logical operator for `from_blocks`."""

pass


class FromNumpy(AbstractFrom):
"""Logical operator for `from_numpy`."""

Expand Down
34 changes: 33 additions & 1 deletion python/ray/data/_internal/pandas_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,9 @@ def to_numpy(
def to_arrow(self) -> "pyarrow.Table":
import pyarrow

return pyarrow.table(self._table)
# Set `preserve_index=False` so that Arrow doesn't add a '__index_level_0__'
# column to the resulting table.
return pyarrow.Table.from_pandas(self._table, preserve_index=False)

@staticmethod
def numpy_to_block(
Expand Down Expand Up @@ -632,3 +634,33 @@ def gen():

def block_type(self) -> BlockType:
return BlockType.PANDAS


def _estimate_dataframe_size(df: "pandas.DataFrame") -> int:
"""Estimate the size of a pandas DataFrame.
This function is necessary because `DataFrame.memory_usage` doesn't count values in
columns with `dtype=object`.
The runtime complexity is linear in the number of values, so don't use this in
performance-critical code.
Args:
df: The DataFrame to estimate the size of.
Returns:
The estimated size of the DataFrame in bytes.
"""
size = 0
for column in df.columns:
if df[column].dtype == object:
for item in df[column]:
if isinstance(item, str):
size += len(item)
elif isinstance(item, np.ndarray):
size += item.nbytes
else:
size += 8 # pandas assumes object values are 8 bytes.
else:
size += df[column].nbytes
return size
4 changes: 1 addition & 3 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,10 +649,8 @@ def capitalize(s: str):
def pandas_df_to_arrow_block(df: "pandas.DataFrame") -> "Block":
from ray.data.block import BlockAccessor, BlockExecStats

block = BlockAccessor.for_block(df).to_arrow()
stats = BlockExecStats.builder()
import pyarrow as pa

block = pa.table(df)
return (
block,
BlockAccessor.for_block(block).get_metadata(
Expand Down
59 changes: 56 additions & 3 deletions python/ray/data/read_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import logging
import math
import os
import warnings
from typing import (
Expand All @@ -25,12 +26,14 @@
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.logical.operators.from_operators import (
FromArrow,
FromBlocks,
FromItems,
FromNumpy,
FromPandas,
)
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.logical.optimizers import LogicalPlan
from ray.data._internal.pandas_block import _estimate_dataframe_size
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.stats import DatasetStats
Expand Down Expand Up @@ -104,6 +107,37 @@
logger = logging.getLogger(__name__)


@DeveloperAPI
def from_blocks(blocks: List[Block]):
"""Create a :class:`~ray.data.Dataset` from a list of blocks.
This method is primarily used for testing. Unlike other methods like
:func:`~ray.data.from_pandas` and :func:`~ray.data.from_arrow`, this method
gaurentees that it won't modify the number of blocks.
Args:
blocks: List of blocks to create the dataset from.
Returns:
A :class:`~ray.data.Dataset` holding the blocks.
"""
block_refs = [ray.put(block) for block in blocks]
metadata = [
BlockAccessor.for_block(block).get_metadata(input_files=None, exec_stats=None)
for block in blocks
]
from_blocks_op = FromBlocks(block_refs, metadata)
logical_plan = LogicalPlan(from_blocks_op)
return MaterializedDataset(
ExecutionPlan(
BlockList(block_refs, metadata, owned_by_consumer=False),
DatasetStats(metadata={"FromBlocks": metadata}, parent=None),
run_by_consumer=False,
),
logical_plan,
)


@PublicAPI
def from_items(
items: List[Any],
Expand Down Expand Up @@ -2359,7 +2393,8 @@ def from_modin(df: "modin.pandas.dataframe.DataFrame") -> MaterializedDataset:

@PublicAPI
def from_pandas(
dfs: Union["pandas.DataFrame", List["pandas.DataFrame"]]
dfs: Union["pandas.DataFrame", List["pandas.DataFrame"]],
override_num_blocks: Optional[int] = None,
) -> MaterializedDataset:
"""Create a :class:`~ray.data.Dataset` from a list of pandas dataframes.
Expand All @@ -2373,10 +2408,14 @@ def from_pandas(
Create a Ray Dataset from a list of Pandas DataFrames.
>>> ray.data.from_pandas([df, df])
MaterializedDataset(num_blocks=2, num_rows=6, schema={a: int64, b: int64})
MaterializedDataset(num_blocks=1, num_rows=6, schema={a: int64, b: int64})
Args:
dfs: A pandas dataframe or a list of pandas dataframes.
override_num_blocks: Override the number of output blocks from all read tasks.
By default, the number of output blocks is dynamically decided based on
input data size and available resources. You shouldn't manually set this
value in most cases.
Returns:
:class:`~ray.data.Dataset` holding data read from the dataframes.
Expand All @@ -2386,13 +2425,27 @@ def from_pandas(
if isinstance(dfs, pd.DataFrame):
dfs = [dfs]

context = DataContext.get_current()
num_blocks = override_num_blocks
if num_blocks is None:
total_size = sum(_estimate_dataframe_size(df) for df in dfs)
num_blocks = max(math.ceil(total_size / context.target_max_block_size), 1)

if len(dfs) > 1:
# I assume most users pass a single DataFrame as input. For simplicity, I'm
# concatenating DataFrames, even though it's not efficient.
ary = pd.concat(dfs, axis=0)
else:
ary = dfs[0]
dfs = np.array_split(ary, num_blocks)

from ray.air.util.data_batch_conversion import (
_cast_ndarray_columns_to_tensor_extension,
)

context = DataContext.get_current()
if context.enable_tensor_extension_casting:
dfs = [_cast_ndarray_columns_to_tensor_extension(df.copy()) for df in dfs]

return from_pandas_refs([ray.put(df) for df in dfs])


Expand Down
25 changes: 23 additions & 2 deletions python/ray/data/tests/test_consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,27 @@ def test_convert_types(ray_start_regular_shared):
assert arrow_ds.map(lambda x: {"a": (x["id"],)}).take() == [{"a": [0]}]


@pytest.mark.parametrize(
"input_blocks",
[
[pd.DataFrame({"column": ["spam"]}), pd.DataFrame({"column": ["ham", "eggs"]})],
[
pa.Table.from_pydict({"column": ["spam"]}),
pa.Table.from_pydict({"column": ["ham", "eggs"]}),
],
],
)
def test_from_blocks(input_blocks, ray_start_regular_shared):
ds = ray.data.from_blocks(input_blocks)

output_blocks = [ray.get(block_ref) for block_ref in ds.get_internal_block_refs()]
assert len(input_blocks) == len(output_blocks)
assert all(
input_block.equals(output_block)
for input_block, output_block in zip(input_blocks, output_blocks)
)


def test_from_items(ray_start_regular_shared):
ds = ray.data.from_items(["hello", "world"])
assert extract_values("item", ds.take()) == ["hello", "world"]
Expand Down Expand Up @@ -781,7 +802,7 @@ def test_iter_batches_basic(ray_start_regular_shared):
df3 = pd.DataFrame({"one": [7, 8, 9], "two": [8, 9, 10]})
df4 = pd.DataFrame({"one": [10, 11, 12], "two": [11, 12, 13]})
dfs = [df1, df2, df3, df4]
ds = ray.data.from_pandas(dfs)
ds = ray.data.from_blocks(dfs)

# Default.
for batch, df in zip(ds.iter_batches(batch_size=None, batch_format="pandas"), dfs):
Expand Down Expand Up @@ -1179,7 +1200,7 @@ def test_iter_batches_grid(ray_start_regular_shared):
)
running_size += block_size
num_rows = running_size
ds = ray.data.from_pandas(dfs)
ds = ray.data.from_blocks(dfs)
for batch_size in np.random.randint(
1, num_rows + 1, size=batch_size_samples
):
Expand Down
4 changes: 2 additions & 2 deletions python/ray/data/tests/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,15 +689,15 @@ def test_csv_write(ray_start_regular_shared, fs, data_path, endpoint_url):
storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url))
# Single block.
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
ds = ray.data.from_pandas([df1])
ds = ray.data.from_blocks([df1])
ds._set_uuid("data")
ds.write_csv(data_path, filesystem=fs)
file_path = os.path.join(data_path, "data_000000_000000.csv")
assert df1.equals(pd.read_csv(file_path, storage_options=storage_options))

# Two blocks.
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
ds = ray.data.from_pandas([df1, df2])
ds = ray.data.from_blocks([df1, df2])
ds._set_uuid("data")
ds.write_csv(data_path, filesystem=fs)
file_path2 = os.path.join(data_path, "data_000001_000000.csv")
Expand Down
6 changes: 3 additions & 3 deletions python/ray/data/tests/test_ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_to_dask(ray_start_regular_shared, ds_format):
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
ds = ray.data.from_blocks([df1, df2])
if ds_format == "arrow":
ds = ds.map_batches(lambda df: df, batch_format="pyarrow", batch_size=None)
ddf = ds.to_dask()
Expand All @@ -52,7 +52,7 @@ def test_to_dask(ray_start_regular_shared, ds_format):
df1["two"] = df1["two"].astype(pd.StringDtype())
df2["two"] = df2["two"].astype(pd.StringDtype())
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
ds = ray.data.from_blocks([df1, df2])
if ds_format == "arrow":
ds = ds.map_batches(lambda df: df, batch_format="pyarrow", batch_size=None)
ddf = ds.to_dask(
Expand All @@ -76,7 +76,7 @@ def test_to_dask(ray_start_regular_shared, ds_format):
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"three": [4, 5, 6], "four": ["e", "f", "g"]})
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
ds = ray.data.from_blocks([df1, df2])
if ds_format == "arrow":
ds = ds.map_batches(lambda df: df, batch_format="pyarrow", batch_size=None)
ddf = ds.to_dask(verify_meta=False)
Expand Down
4 changes: 2 additions & 2 deletions python/ray/data/tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url):
storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url))
# Single block.
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
ds = ray.data.from_pandas([df1])
ds = ray.data.from_blocks([df1])
ds._set_uuid("data")
ds.write_json(data_path, filesystem=fs)
file_path = os.path.join(data_path, "data_000000_000000.json")
Expand All @@ -500,7 +500,7 @@ def test_json_write(ray_start_regular_shared, fs, data_path, endpoint_url):

# Two blocks.
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
ds = ray.data.from_pandas([df1, df2])
ds = ray.data.from_blocks([df1, df2])
ds._set_uuid("data")
ds.write_json(data_path, filesystem=fs)
file_path2 = os.path.join(data_path, "data_000001_000000.json")
Expand Down
4 changes: 2 additions & 2 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def test_map_batches_block_bundling_skewed_manual(
ray_start_regular_shared, block_sizes, batch_size, expected_num_blocks
):
num_blocks = len(block_sizes)
ds = ray.data.from_pandas(
ds = ray.data.from_blocks(
[pd.DataFrame({"a": [1] * block_size}) for block_size in block_sizes]
)
# Confirm that we have the expected number of initial blocks.
Expand All @@ -856,7 +856,7 @@ def test_map_batches_block_bundling_skewed_auto(
ray_start_regular_shared, block_sizes, batch_size
):
num_blocks = len(block_sizes)
ds = ray.data.from_pandas(
ds = ray.data.from_blocks(
[pd.DataFrame({"a": [1] * block_size}) for block_size in block_sizes]
)
# Confirm that we have the expected number of initial blocks.
Expand Down
27 changes: 26 additions & 1 deletion python/ray/data/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,31 @@ def test_from_pandas(ray_start_regular_shared, enable_pandas_block):
ctx.enable_pandas_block = old_enable_pandas_block


def test_from_pandas_default_num_blocks(ray_start_regular_shared, restore_data_context):
ray.data.DataContext.get_current().target_max_block_size = 8 * 1024 * 1024 # 8 MiB

record = {"number": 0, "string": "\0"}
record_size_bytes = 8 + 1 # 8 bytes for int64 and 1 byte for char
dataframe_size_bytes = 64 * 1024 * 1024 # 64 MiB
num_records = int(dataframe_size_bytes / record_size_bytes)
df = pd.DataFrame.from_records([record] * num_records)

ds = ray.data.from_pandas(df)

# If the target block size is 8 MiB, the DataFrame should be split into
# 64 MiB / (8 MiB / block) = 8 blocks.
assert ds.materialize().num_blocks() == 8


@pytest.mark.parametrize("num_inputs", [1, 2])
def test_from_pandas_override_num_blocks(num_inputs, ray_start_regular_shared):
df = pd.DataFrame({"number": [0]})

ds = ray.data.from_pandas([df] * num_inputs, override_num_blocks=2)

assert ds.materialize().num_blocks() == 2


@pytest.mark.parametrize("enable_pandas_block", [False, True])
def test_from_pandas_refs(ray_start_regular_shared, enable_pandas_block):
ctx = ray.data.context.DataContext.get_current()
Expand Down Expand Up @@ -113,7 +138,7 @@ def test_to_pandas_refs(ray_start_regular_shared):
def test_pandas_roundtrip(ray_start_regular_shared, tmp_path):
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
ds = ray.data.from_pandas([df1, df2])
ds = ray.data.from_pandas([df1, df2], override_num_blocks=2)
dfds = ds.to_pandas()
assert pd.concat([df1, df2], ignore_index=True).equals(dfds)

Expand Down
4 changes: 2 additions & 2 deletions python/ray/data/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def test_parquet_write(ray_start_regular_shared, fs, data_path, endpoint_url):
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
ds = ray.data.from_blocks([df1, df2])
path = os.path.join(data_path, "test_parquet_dir")
if fs is None:
os.mkdir(path)
Expand Down Expand Up @@ -928,7 +928,7 @@ def test_parquet_write_create_dir(
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
ds = ray.data.from_blocks([df1, df2])
path = os.path.join(data_path, "test_parquet_dir")
# Set the uuid to a known value so that we can easily get the parquet file names.
data_key = "data"
Expand Down
4 changes: 2 additions & 2 deletions python/ray/data/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_sort_arrow(
offset += shard
if offset < num_items:
dfs.append(pd.DataFrame({"a": a[offset:], "b": b[offset:]}))
ds = ray.data.from_pandas(dfs).map_batches(
ds = ray.data.from_blocks(dfs).map_batches(
lambda t: t, batch_format="pyarrow", batch_size=None
)

Expand Down Expand Up @@ -235,7 +235,7 @@ def test_sort_pandas(ray_start_regular, num_items, parallelism, use_push_based_s
offset += shard
if offset < num_items:
dfs.append(pd.DataFrame({"a": a[offset:], "b": b[offset:]}))
ds = ray.data.from_pandas(dfs)
ds = ray.data.from_blocks(dfs)

def assert_sorted(sorted_ds, expected_rows):
assert [tuple(row.values()) for row in sorted_ds.iter_rows()] == list(
Expand Down

0 comments on commit f13d144

Please sign in to comment.