Make a pass fixing Dataset API issues (#22886)
ericl authored Mar 8, 2022
1 parent ab2741d commit 52491c8
Showing 13 changed files with 134 additions and 291 deletions.
6 changes: 3 additions & 3 deletions doc/source/data/dataset-tensor-support.rst
@@ -61,7 +61,7 @@ If you already have a Parquet dataset with columns containing serialized tensors
# Read the Parquet files into a new Dataset, with the serialized tensors
# automatically cast to our tensor column extension type.
ds = ray.data.read_parquet(
path, _tensor_column_schema={"two": (np.int, (2, 2, 2))})
path, tensor_column_schema={"two": (np.int, (2, 2, 2))})
# Internally, this column is represented with our Arrow tensor extension
# type.
@@ -107,7 +107,7 @@ If your serialized tensors don't fit the above constraints (e.g. they're stored
# -> one: int64
# two: extension<arrow.py_extension_type<ArrowTensorType>>
Please note that the ``_tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.
Please note that the ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.

Working with tensor column datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -245,6 +245,6 @@ Limitations
This feature currently comes with a few known limitations that we are either actively working on addressing or have already implemented workarounds for.

* All tensors in a tensor column currently must be the same shape. Please let us know if you require heterogeneous tensor shapes for your tensor column! Tracking issue is `here <https://github.com/ray-project/ray/issues/18316>`__.
* Automatic casting via specifying an override Arrow schema when reading Parquet is blocked on Arrow support for custom ExtensionType casting kernels. See `issue <https://issues.apache.org/jira/browse/ARROW-5890>`__. An explicit ``_tensor_column_schema`` parameter has been added for :func:`read_parquet() <ray.data.read_api.read_parquet>` as a stopgap solution.
* Automatic casting via specifying an override Arrow schema when reading Parquet is blocked on Arrow support for custom ExtensionType casting kernels. See `issue <https://issues.apache.org/jira/browse/ARROW-5890>`__. An explicit ``tensor_column_schema`` parameter has been added for :func:`read_parquet() <ray.data.read_api.read_parquet>` as a stopgap solution.
* Ingesting tables with tensor columns into PyTorch via ``ds.to_torch()`` is blocked on PyTorch support for tensor creation from objects that implement the ``__array__`` interface. See `issue <https://github.com/pytorch/pytorch/issues/51156>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18314>`__.
* Ingesting tables with tensor columns into TensorFlow via ``ds.to_tf()`` is blocked on the release of a Pandas fix for properly interpreting extension arrays in ``DataFrame.values``. See `PR <https://github.com/pandas-dev/pandas/pull/43160>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18315>`__.
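For reference, a minimal usage sketch of the renamed argument (not part of the commit): the path is a placeholder, the column layout follows the docs snippet above, and ``np.int64`` stands in for the deprecated ``np.int`` alias.

```python
import numpy as np
import ray

# Placeholder path; assumes a Parquet dataset whose "two" column stores
# serialized 2x2x2 integer tensors, as in the docs example above.
path = "/tmp/tensor_dataset"

# After this commit the stopgap schema override is spelled
# ``tensor_column_schema`` (formerly ``_tensor_column_schema``).
ds = ray.data.read_parquet(
    path,
    tensor_column_schema={"two": (np.int64, (2, 2, 2))},
)
print(ds.schema())
```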
5 changes: 1 addition & 4 deletions doc/source/data/examples/big_data_ingestion.ipynb
@@ -270,17 +270,14 @@
"def create_large_shuffle_pipeline(\n",
" data_size_bytes: int, num_epochs: int, num_columns: int, num_shards: int\n",
") -> List[DatasetPipeline]:\n",
" # _spread_resource_prefix is used to ensure tasks are evenly spread to all\n",
" # CPU nodes.\n",
" return (\n",
" ray.data.read_datasource(\n",
" RandomIntRowDatasource(),\n",
" n=data_size_bytes // 8 // num_columns,\n",
" num_columns=num_columns,\n",
" _spread_resource_prefix=\"node:\",\n",
" )\n",
" .repeat(num_epochs)\n",
" .random_shuffle_each_window(_spread_resource_prefix=\"node:\")\n",
" .random_shuffle_each_window()\n",
" .split(num_shards, equal=True)\n",
" )"
]
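As a rough consumption sketch (not part of the commit), each shard returned by ``create_large_shuffle_pipeline`` above is a ``DatasetPipeline`` that a training worker would iterate epoch by epoch; the sizes below are arbitrary.

```python
# Hypothetical consumption of one shard produced by the pipeline above.
shards = create_large_shuffle_pipeline(
    data_size_bytes=1_000_000, num_epochs=2, num_columns=16, num_shards=2
)
for epoch_pipeline in shards[0].iter_epochs():
    for batch in epoch_pipeline.iter_batches(batch_size=1024):
        pass  # hand the shuffled batch to a training step here
```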
(next changed file; file name not shown in this view)
@@ -121,9 +121,7 @@ def create_data_chunk(n, d, seed, include_label=False):

def read_dataset(path: str) -> ray.data.Dataset:
print(f"reading data from {path}")
return ray.data.read_parquet(path, _spread_resource_prefix="node:").random_shuffle(
_spread_resource_prefix="node:"
)
return ray.data.read_parquet(path).random_shuffle()


class DataPreprocessor:
@@ -596,9 +594,7 @@ def train_func(config):
DROPOUT_PROB = 0.2

# Random global shuffle
train_dataset_pipeline = train_dataset.repeat().random_shuffle_each_window(
_spread_resource_prefix="node:"
)
train_dataset_pipeline = train_dataset.repeat().random_shuffle_each_window()
del train_dataset

datasets = {"train_dataset": train_dataset_pipeline, "test_dataset": test_dataset}
14 changes: 7 additions & 7 deletions python/ray/data/aggregate.py
@@ -14,7 +14,7 @@
from ray.data import Dataset


@PublicAPI(stability="beta")
@PublicAPI
class AggregateFn(object):
def __init__(
self,
@@ -64,7 +64,7 @@ def _validate(self, ds: "Dataset") -> None:
_validate_key_fn(ds, self._key_fn)


@PublicAPI(stability="beta")
@PublicAPI
class Count(AggregateFn):
"""Defines count aggregation."""

@@ -77,7 +77,7 @@ def __init__(self):
)


@PublicAPI(stability="beta")
@PublicAPI
class Sum(_AggregateOnKeyBase):
"""Defines sum aggregation."""

@@ -94,7 +94,7 @@ def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
)


@PublicAPI(stability="beta")
@PublicAPI
class Min(_AggregateOnKeyBase):
"""Defines min aggregation."""

@@ -111,7 +111,7 @@ def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
)


@PublicAPI(stability="beta")
@PublicAPI
class Max(_AggregateOnKeyBase):
"""Defines max aggregation."""

@@ -128,7 +128,7 @@ def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
)


@PublicAPI(stability="beta")
@PublicAPI
class Mean(_AggregateOnKeyBase):
"""Defines mean aggregation."""

@@ -149,7 +149,7 @@ def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
)


@PublicAPI(stability="beta")
@PublicAPI
class Std(_AggregateOnKeyBase):
"""Defines standard deviation aggregation.
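The aggregation classes above are now exposed as plain ``@PublicAPI`` rather than beta. A hedged usage sketch with throwaway data (the DataFrame below is illustrative, not from the commit):

```python
import pandas as pd
import ray
from ray.data.aggregate import Count, Max, Mean, Sum

# Small illustrative tabular dataset with a group key and a value column.
df = pd.DataFrame({"group": [0, 1, 0, 1, 0], "value": [1, 2, 3, 4, 5]})
ds = ray.data.from_pandas(df)

# Per-group aggregation...
print(ds.groupby("group").aggregate(Sum("value"), Mean("value")).take())

# ...and whole-dataset aggregation.
print(ds.aggregate(Count(), Max("value")))
```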
50 changes: 20 additions & 30 deletions python/ray/data/dataset.py
@@ -85,7 +85,7 @@
_epoch_warned = False


@PublicAPI(stability="beta")
@PublicAPI
class Dataset(Generic[T]):
"""Implements a distributed Arrow dataset.
@@ -167,7 +167,7 @@ def map(
ray (e.g., num_gpus=1 to request GPUs for the map tasks).
"""

fn = cache_wrapper(fn)
fn = cache_wrapper(fn, compute)
context = DatasetContext.get_current()

def transform(block: Block) -> Iterable[Block]:
@@ -240,7 +240,7 @@ def map_batches(
import pyarrow as pa
import pandas as pd

fn = cache_wrapper(fn)
fn = cache_wrapper(fn, compute)
context = DatasetContext.get_current()

def transform(block: Block) -> Iterable[Block]:
@@ -370,7 +370,7 @@ def flat_map(
ray (e.g., num_gpus=1 to request GPUs for the map tasks).
"""

fn = cache_wrapper(fn)
fn = cache_wrapper(fn, compute)
context = DatasetContext.get_current()

def transform(block: Block) -> Iterable[Block]:
@@ -417,7 +417,7 @@ def filter(
ray (e.g., num_gpus=1 to request GPUs for the map tasks).
"""

fn = cache_wrapper(fn)
fn = cache_wrapper(fn, compute)
context = DatasetContext.get_current()

def transform(block: Block) -> Iterable[Block]:
@@ -504,8 +504,6 @@ def random_shuffle(
*,
seed: Optional[int] = None,
num_blocks: Optional[int] = None,
_spread_resource_prefix: Optional[str] = None,
_move: bool = False, # TODO: deprecate.
) -> "Dataset[T]":
"""Randomly shuffle the elements of this dataset.
@@ -534,7 +532,7 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
num_blocks = block_list.executed_num_blocks() # Blocking.
if num_blocks == 0:
return block_list, {}
if _move or clear_input_blocks:
if clear_input_blocks:
blocks = block_list.copy()
block_list.clear()
else:
@@ -545,7 +543,6 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
num_blocks,
random_shuffle=True,
random_seed=seed,
_spread_resource_prefix=_spread_resource_prefix,
map_ray_remote_args=remote_args,
reduce_ray_remote_args=remote_args,
)
@@ -2523,11 +2520,6 @@ def __iter__(self):
)
return pipe

def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]":
raise DeprecationWarning(
"Use .window(blocks_per_window=n) instead of " ".pipeline(parallelism=n)"
)

def window(
self,
*,
@@ -2670,21 +2662,6 @@ def __iter__(self):
)
return pipe

@DeveloperAPI
def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
"""Get a list of references to the underlying blocks of this dataset.
This function can be used for zero-copy access to the data. It blocks
until the underlying blocks are computed.
Time complexity: O(1)
Returns:
A list of references to this dataset's blocks.
"""
return self._plan.execute().get_blocks()

@DeveloperAPI
def fully_executed(self) -> "Dataset[T]":
"""Force full evaluation of the blocks of this dataset.
@@ -2709,11 +2686,24 @@ def fully_executed(self) -> "Dataset[T]":
ds._set_uuid(self._get_uuid())
return ds

@DeveloperAPI
def stats(self) -> str:
"""Returns a string containing execution timing information."""
return self._plan.stats().summary_string()

@DeveloperAPI
def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
"""Get a list of references to the underlying blocks of this dataset.
This function can be used for zero-copy access to the data. It blocks
until the underlying blocks are computed.
Time complexity: O(1)
Returns:
A list of references to this dataset's blocks.
"""
return self._plan.execute().get_blocks()

def _experimental_lazy(self) -> "Dataset[T]":
"""Enable lazy evaluation (experimental)."""
self._lazy = True
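To tie the ``dataset.py`` changes together, a hedged sketch of the user-facing calls they touch: ``compute`` is now threaded through to ``cache_wrapper`` (class-based UDFs pair with actor compute), ``random_shuffle`` drops its private ``_spread_resource_prefix``/``_move`` arguments, and ``stats()``/``get_internal_block_refs()`` are marked ``@DeveloperAPI``. Everything below is illustrative, assuming a Ray build with these APIs.

```python
import ray

ds = ray.data.range(1000)

# A class-based UDF used with actor compute; the compute argument is what
# the diff now passes into ``cache_wrapper``.
class AddOne:
    def __call__(self, x: int) -> int:
        return x + 1

mapped = ds.map(AddOne, compute="actors")

# Only public knobs remain on random_shuffle.
shuffled = mapped.random_shuffle(seed=42)

# Developer APIs touched by the diff: timing stats and zero-copy block access.
print(shuffled.stats())
block_refs = shuffled.get_internal_block_refs()
```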
35 changes: 18 additions & 17 deletions python/ray/data/dataset_pipeline.py
@@ -35,24 +35,24 @@
logger = logging.getLogger(__name__)

# Operations that can be naively applied per dataset row in the pipeline.
PER_DATASET_OPS = ["map", "map_batches", "add_column", "flat_map", "filter"]
_PER_DATASET_OPS = ["map", "map_batches", "add_column", "flat_map", "filter"]

# Operations that apply to each dataset holistically in the pipeline.
HOLISTIC_PER_DATASET_OPS = ["repartition", "random_shuffle", "sort"]
_HOLISTIC_PER_DATASET_OPS = ["repartition", "random_shuffle", "sort"]

# Similar to above but we should force evaluation immediately.
PER_DATASET_OUTPUT_OPS = [
_PER_DATASET_OUTPUT_OPS = [
"write_json",
"write_csv",
"write_parquet",
"write_datasource",
]

# Operations that operate over the stream of output batches from the pipeline.
OUTPUT_ITER_OPS = ["take", "take_all", "show", "to_tf", "to_torch"]
_OUTPUT_ITER_OPS = ["take", "take_all", "show", "to_tf", "to_torch"]


@PublicAPI(stability="beta")
@PublicAPI
class DatasetPipeline(Generic[T]):
"""Implements a pipeline of Datasets.
@@ -475,19 +475,26 @@ def __iter__(self):
length=length,
)

def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
def schema(
self, fetch_if_missing: bool = False
) -> Union[type, "pyarrow.lib.Schema"]:
"""Return the schema of the dataset pipeline.
For datasets of Arrow records, this will return the Arrow schema.
For datasets of Python objects, this returns their Python type.
Time complexity: O(1)
Args:
fetch_if_missing: If True, synchronously fetch the schema if it's
not known. Default is False, where None is returned if the
schema is not known.
Returns:
The Python type or Arrow schema of the records, or None if the
schema is not known.
"""
return next(self.iter_datasets()).schema()
return next(self.iter_datasets()).schema(fetch_if_missing=fetch_if_missing)
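A quick hedged sketch (not from the commit) of the new flag; as with ``Dataset.schema``, the first window's schema is only computed when explicitly requested:

```python
import ray

pipe = ray.data.range(100).window(blocks_per_window=10)

# With fetch_if_missing=False (the default), None may be returned if the
# schema is not already known; True forces fetching it from the first window.
print(pipe.schema(fetch_if_missing=True))
```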

def count(self) -> int:
"""Count the number of records in the dataset pipeline.
@@ -670,12 +677,6 @@ def foreach_window(
_executed=self._executed,
)

def foreach_dataset(self, *a, **kw) -> None:
raise DeprecationWarning(
"`foreach_dataset` has been renamed to `foreach_window`."
)

@DeveloperAPI
def stats(self, exclude_first_window: bool = True) -> str:
"""Returns a string containing execution timing information.
@@ -743,7 +744,7 @@ def _optimize_stages(self):
self._optimized_stages = optimized_stages


for method in PER_DATASET_OPS:
for method in _PER_DATASET_OPS:

def make_impl(method):
delegate = getattr(Dataset, method)
@@ -766,7 +767,7 @@ def impl(self, *args, **kwargs) -> "DatasetPipeline[U]":

setattr(DatasetPipeline, method, make_impl(method))

for method in HOLISTIC_PER_DATASET_OPS:
for method in _HOLISTIC_PER_DATASET_OPS:

def make_impl(method):
delegate = getattr(Dataset, method)
@@ -798,7 +799,7 @@ def impl(*a, **kw):
setattr(DatasetPipeline, method, deprecation_warning(method))
setattr(DatasetPipeline, method + "_each_window", make_impl(method))

for method in PER_DATASET_OUTPUT_OPS:
for method in _PER_DATASET_OUTPUT_OPS:

def make_impl(method):
delegate = getattr(Dataset, method)
@@ -822,7 +823,7 @@ def impl(self, *args, **kwargs):

setattr(DatasetPipeline, method, make_impl(method))

for method in OUTPUT_ITER_OPS:
for method in _OUTPUT_ITER_OPS:

def make_impl(method):
delegate = getattr(Dataset, method)
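The renamed (now underscore-prefixed) operation lists above still drive how ``Dataset`` methods are stamped onto ``DatasetPipeline``. A hedged end-to-end sketch of the generated methods:

```python
import ray

# ``map`` comes from _PER_DATASET_OPS (applied window by window) and
# ``show`` from _OUTPUT_ITER_OPS (applied to the output batch stream).
pipe = (
    ray.data.range(1000)
    .window(blocks_per_window=10)
    .map(lambda x: x * 2)
)
pipe.show(5)
```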
2 changes: 1 addition & 1 deletion python/ray/data/grouped_dataset.py
@@ -15,7 +15,7 @@
from ray.data.block import Block, BlockAccessor, BlockMetadata, T, U, KeyType


@PublicAPI(stability="beta")
@PublicAPI
class GroupedDataset(Generic[T]):
"""Represents a grouped dataset created by calling ``Dataset.groupby()``.
(Diffs for the remaining changed files are not shown in this view.)
