[Datasets] Support tensor columns in to_tf and to_torch. (ray-project#24752)

This PR adds support for tensor columns in the to_tf() and to_torch() APIs.

For Torch, this involves an explicit extension array check and (zero-copy) conversion of the tensor column to a NumPy array before converting the column to a Torch tensor.

For TensorFlow, this involves bypassing df.values when converting tensor feature columns to NumPy arrays, instead manually creating a single NumPy array from the column Series.

In both cases, I think that the UX around heterogeneous feature columns and squeezing the column dimension could be improved, but I'm saving that for a future PR.
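The extension array check and the per-column conversion described above can be sketched with plain Pandas and NumPy. This is only an illustration: `has_extension_columns` and `stack_columns` are hypothetical helper names, and the built-in `Int64` extension dtype stands in for a Ray tensor column.

```python
import numpy as np
import pandas as pd


def has_extension_columns(df: pd.DataFrame) -> bool:
    """Return True if any column is backed by a Pandas extension array."""
    return any(
        isinstance(dtype, pd.api.extensions.ExtensionDtype) for dtype in df.dtypes
    )


def stack_columns(df: pd.DataFrame) -> np.ndarray:
    """Convert each column to NumPy on its own, then stack along axis 1."""
    return np.stack([col.to_numpy() for _, col in df.items()], axis=1)


# "Int64" is a built-in Pandas extension dtype (assumed stand-in for a tensor column).
ext_df = pd.DataFrame({
    "a": pd.array([1, 2, 3], dtype="Int64"),
    "b": pd.array([4, 5, 6], dtype="Int64"),
})
plain_df = pd.DataFrame({"a": [1, 2, 3]})

print(has_extension_columns(ext_df))
print(has_extension_columns(plain_df))
print(stack_columns(ext_df).shape)
```

Stacking along axis 1 keeps one entry per column in the second dimension, so a real tensor column would keep its element shape after that axis.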
clarkzinzow committed May 20, 2022
1 parent d589329 commit a930b2f
Showing 4 changed files with 327 additions and 45 deletions.
167 changes: 164 additions & 3 deletions doc/source/data/dataset-tensor-support.rst
@@ -107,7 +107,9 @@ If your serialized tensors don't fit the above constraints (e.g. they're stored
 # -> one: int64
 # two: extension<arrow.py_extension_type<ArrowTensorType>>

-Please note that the ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.
+.. note::
+
+  The ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.

Working with tensor column datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -143,6 +145,167 @@ This dataset can then be written to Parquet files. The tensor column schema will
# -> one: int64
# two: extension<arrow.py_extension_type<ArrowTensorType>>
Converting to a Torch/TensorFlow Dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This dataset can also be converted to a Torch or TensorFlow dataset via the standard
:meth:`ds.to_torch() <ray.data.Dataset.to_torch>` and
:meth:`ds.to_tf() <ray.data.Dataset.to_tf>` APIs for ingestion into those respective ML
training frameworks. The tensor column will be automatically converted to a
Torch/TensorFlow tensor without incurring any copies.

.. note::

When converting to a TensorFlow Dataset, you will need to give the full tensor spec
for the tensor columns, including the shape of each underlying tensor element in said
column.
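As a rough illustration of what the full tensor spec has to describe, the sketch below builds a batch for a single tensor feature column with NumPy and checks it against a shape of the form `(None, 1) + tensor_element_shape`. The `matches_spec` helper is hypothetical and is not part of TensorFlow or Ray; it assumes that `None` stands for the variable batch dimension and that the `1` is the column dimension.

```python
import numpy as np

tensor_element_shape = (32, 32)
# None is the batch dimension; 1 is the column dimension (one feature column).
spec_shape = (None, 1) + tensor_element_shape


def matches_spec(array: np.ndarray, spec_shape: tuple) -> bool:
    """Check an array's shape against a spec in which None matches any size."""
    if array.ndim != len(spec_shape):
        return False
    return all(
        want is None or want == got for want, got in zip(spec_shape, array.shape)
    )


# A batch of two elements from a single tensor feature column.
column = np.zeros((2,) + tensor_element_shape, dtype=np.float32)
# Stacking the column along axis 1 adds the column dimension.
batch = np.stack([column], axis=1)

print(batch.shape)
print(matches_spec(batch, spec_shape))
print(matches_spec(column, spec_shape))
```

The bare column has no column dimension, so it does not match the spec until it is stacked along axis 1.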


.. tabbed:: Torch

Convert a ``Dataset`` containing a single tensor feature column to a Torch ``IterableDataset``.

.. code-block:: python

    import ray
    import numpy as np
    import pandas as pd
    import torch
    from ray.data.extensions import TensorArray

    df = pd.DataFrame({
        "feature": TensorArray(np.arange(4096).reshape((4, 32, 32))),
        "label": [1, 2, 3, 4],
    })
    ds = ray.data.from_pandas(df)

    # Convert the dataset to a Torch IterableDataset.
    torch_ds = ds.to_torch(
        label_column="label",
        batch_size=2,
        unsqueeze_label_tensor=False,
        unsqueeze_feature_tensors=False,
    )

    # A feature tensor and a label tensor are yielded per batch.
    for X, y in torch_ds:
        # Train model(X, y)
        pass

.. tabbed:: TensorFlow

Convert a ``Dataset`` containing a single tensor feature column to a TensorFlow ``tf.data.Dataset``.

.. code-block:: python

    import ray
    import numpy as np
    import pandas as pd
    import tensorflow as tf
    from ray.data.extensions import TensorArray

    tensor_element_shape = (32, 32)
    df = pd.DataFrame({
        "feature": TensorArray(np.arange(4096).reshape((4,) + tensor_element_shape)),
        "label": [1, 2, 3, 4],
    })
    ds = ray.data.from_pandas(df)

    # Convert the dataset to a TensorFlow Dataset.
    tf_ds = ds.to_tf(
        label_column="label",
        output_signature=(
            tf.TensorSpec(shape=(None, 1) + tensor_element_shape, dtype=tf.float32),
            tf.TensorSpec(shape=(None,), dtype=tf.float32),
        ),
        batch_size=2,
    )

    # A feature tensor and a label tensor are yielded per batch.
    for X, y in tf_ds:
        # Train model(X, y)
        pass

If your columns have different types **OR** your (tensor) columns have different shapes,
these columns are incompatible and you will not be able to stack the column tensors
into a single tensor. Instead, you will need to group the columns by compatibility in
the ``feature_columns`` argument.

E.g., if columns ``"feature_1"`` and ``"feature_2"`` are incompatible, you should give
``to_torch()`` a ``feature_columns=[["feature_1"], ["feature_2"]]`` argument in order to
instruct it to return separate tensors for ``"feature_1"`` and ``"feature_2"``. For
``to_torch()``, if isolating single columns as in the ``"feature_1"`` + ``"feature_2"``
example, you may also want to provide ``unsqueeze_feature_tensors=False`` in order to
remove the redundant column dimension for each of the unit column tensors.
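The incompatibility and the redundant column dimension can be seen with plain NumPy. The sketch below only mimics the shapes involved; it is not Ray's implementation of ``feature_columns`` or ``unsqueeze_feature_tensors``, and the small ``(2, 2)`` element shape is chosen purely for illustration.

```python
import numpy as np

batch_size = 4
# A tensor column with (2, 2) elements and a scalar column.
tensor_col = np.arange(batch_size * 2 * 2, dtype=np.float32).reshape(
    (batch_size, 2, 2)
)
scalar_col = np.array([5, 6, 7, 8], dtype=np.float32)

# Columns with different element shapes cannot be stacked into one tensor.
try:
    np.stack([tensor_col, scalar_col], axis=1)
    stackable = True
except ValueError:
    stackable = False
print(stackable)

# Putting each column in its own group yields one array per group. With the
# column dimension kept, every group carries a redundant axis of length 1.
unsqueezed = tuple(np.stack([col], axis=1) for col in (tensor_col, scalar_col))
print([a.shape for a in unsqueezed])

# Dropping that axis leaves only the batch and element dimensions.
squeezed = tuple(np.squeeze(a, axis=1) for a in unsqueezed)
print([a.shape for a in squeezed])
```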

.. tabbed:: Torch

Convert a ``Dataset`` containing a tensor feature column and a scalar feature column
to a Torch ``IterableDataset``.

.. code-block:: python

    import ray
    import numpy as np
    import pandas as pd
    import torch
    from ray.data.extensions import TensorArray

    df = pd.DataFrame({
        "feature_1": TensorArray(np.arange(4096).reshape((4, 32, 32))),
        "feature_2": [5, 6, 7, 8],
        "label": [1, 2, 3, 4],
    })
    ds = ray.data.from_pandas(df)

    # Convert the dataset to a Torch IterableDataset.
    torch_ds = ds.to_torch(
        label_column="label",
        feature_columns=[["feature_1"], ["feature_2"]],
        batch_size=2,
        unsqueeze_label_tensor=False,
        unsqueeze_feature_tensors=False,
    )

    # Two feature tensors and one label tensor are yielded per batch.
    for (feature_1, feature_2), y in torch_ds:
        # Train model((feature_1, feature_2), y)
        pass

.. tabbed:: TensorFlow

Convert a ``Dataset`` containing a tensor feature column and a scalar feature column
to a TensorFlow ``tf.data.Dataset``.

.. code-block:: python

    import ray
    import numpy as np
    import pandas as pd
    import tensorflow as tf
    from ray.data.extensions import TensorArray

    tensor_element_shape = (32, 32)
    df = pd.DataFrame({
        "feature_1": TensorArray(np.arange(4096).reshape((4,) + tensor_element_shape)),
        "feature_2": [5, 6, 7, 8],
        "label": [1, 2, 3, 4],
    })
    ds = ray.data.from_pandas(df)

    # Convert the dataset to a TensorFlow Dataset.
    tf_ds = ds.to_tf(
        label_column="label",
        feature_columns=[["feature_1"], ["feature_2"]],
        output_signature=(
            (
                tf.TensorSpec(shape=(None, 1) + tensor_element_shape, dtype=tf.float32),
                tf.TensorSpec(shape=(None, 1), dtype=tf.int64),
            ),
            tf.TensorSpec(shape=(None,), dtype=tf.float32),
        ),
        batch_size=2,
    )

    # Two feature tensors and one label tensor are yielded per batch.
    for (feature_1, feature_2), y in tf_ds:
        # Train model((feature_1, feature_2), y)
        pass

End-to-end workflow with our Pandas extension type
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -246,5 +409,3 @@ This feature currently comes with a few known limitations that we are either act

 * All tensors in a tensor column currently must be the same shape. Please let us know if you require heterogeneous tensor shape for your tensor column! Tracking issue is `here <https://github.com/ray-project/ray/issues/18316>`__.
 * Automatic casting via specifying an override Arrow schema when reading Parquet is blocked by Arrow supporting custom ExtensionType casting kernels. See `issue <https://issues.apache.org/jira/browse/ARROW-5890>`__. An explicit ``tensor_column_schema`` parameter has been added for :func:`read_parquet() <ray.data.read_api.read_parquet>` as a stopgap solution.
-* Ingesting tables with tensor columns into pytorch via ``ds.to_torch()`` is blocked by pytorch supporting tensor creation from objects that implement the `__array__` interface. See `issue <https://github.com/pytorch/pytorch/issues/51156>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18314>`__.
-* Ingesting tables with tensor columns into TensorFlow via ``ds.to_tf()`` is blocked by a Pandas fix for properly interpreting extension arrays in ``DataFrame.values`` being released. See `PR <https://github.com/pandas-dev/pandas/pull/43160>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18315>`__.
31 changes: 25 additions & 6 deletions python/ray/data/dataset.py
@@ -2380,6 +2380,27 @@ def to_tf(
         if isinstance(output_signature, list):
             output_signature = tuple(output_signature)

+        def get_df_values(df: "pandas.DataFrame") -> np.ndarray:
+            # TODO(Clark): Support unsqueezing column dimension API, similar to
+            # to_torch().
+            try:
+                values = df.values
+            except ValueError as e:
+                import pandas as pd
+
+                # Pandas DataFrame.values doesn't support extension arrays in all
+                # supported Pandas versions, so we check to see if this DataFrame
+                # contains any extensions arrays and do a manual conversion if so.
+                # See https://github.com/pandas-dev/pandas/pull/43160.
+                if any(
+                    isinstance(dtype, pd.api.extensions.ExtensionDtype)
+                    for dtype in df.dtypes
+                ):
+                    values = np.stack([col.to_numpy() for _, col in df.items()], axis=1)
+                else:
+                    raise e from None
+            return values
+
         def make_generator():
             for batch in self.iter_batches(
                 prefetch_blocks=prefetch_blocks,
@@ -2392,13 +2413,13 @@ def make_generator():

                 features = None
                 if feature_columns is None:
-                    features = batch.values
+                    features = get_df_values(batch)
                 elif isinstance(feature_columns, list):
                     if all(isinstance(column, str) for column in feature_columns):
-                        features = batch[feature_columns].values
+                        features = get_df_values(batch[feature_columns])
                     elif all(isinstance(columns, list) for columns in feature_columns):
                         features = tuple(
-                            batch[columns].values for columns in feature_columns
+                            get_df_values(batch[columns]) for columns in feature_columns
                         )
                     else:
                         raise ValueError(
@@ -2407,7 +2428,7 @@ def make_generator():
                         )
                 elif isinstance(feature_columns, dict):
                     features = {
-                        key: batch[columns].values
+                        key: get_df_values(batch[columns])
                         for key, columns in feature_columns.items()
                     }
                 else:
@@ -2416,8 +2437,6 @@ def make_generator():
                         f"but got a `{type(feature_columns).__name__}` instead."
                     )

-                # TODO(Clark): Support batches containing our extension array
-                # TensorArray.
                 if label_column:
                     yield features, targets
                 else: