[Datasets] Support tensor columns in to_tf and to_torch. #24752

Merged
167 changes: 164 additions & 3 deletions doc/source/data/dataset-tensor-support.rst
@@ -107,7 +107,9 @@ If your serialized tensors don't fit the above constraints (e.g. they're stored
# -> one: int64
# two: extension<arrow.py_extension_type<ArrowTensorType>>

Please note that the ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.
.. note::

The ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.

Working with tensor column datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -143,6 +145,167 @@ This dataset can then be written to Parquet files. The tensor column schema will
# -> one: int64
# two: extension<arrow.py_extension_type<ArrowTensorType>>

Converting to a Torch/TensorFlow Dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This dataset can also be converted to a Torch or TensorFlow dataset via the standard
:meth:`ds.to_torch() <ray.data.Dataset.to_torch>` and
:meth:`ds.to_tf() <ray.data.Dataset.to_tf>` APIs for ingestion into those respective ML
training frameworks. The tensor column will be automatically converted to a
Torch/TensorFlow tensor without incurring any copies.

.. note::

When converting to a TensorFlow Dataset, you will need to provide the full tensor spec
for each tensor column, including the shape of each underlying tensor element in that
column.
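The "without incurring any copies" claim above rests on Torch's ability to wrap a NumPy array's existing memory. A minimal sketch of that mechanism, independent of Ray and the Datasets API:

```python
import numpy as np
import torch

arr = np.arange(6, dtype=np.float32).reshape(2, 3)

# torch.as_tensor wraps the ndarray's existing buffer; no data is copied
# for CPU tensors with compatible dtypes.
tensor = torch.as_tensor(arr)

# Mutating the ndarray is visible through the tensor, showing that the two
# share the same underlying memory.
arr[0, 0] = 42.0
```

This is the general NumPy-to-Torch zero-copy path; the exact conversion calls inside ``to_torch()`` may differ.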


.. tabbed:: Torch

Convert a ``Dataset`` containing a single tensor feature column to a Torch ``IterableDataset``.

.. code-block:: python

import ray
import numpy as np
import pandas as pd
import torch

from ray.data.extensions import TensorArray

df = pd.DataFrame({
    "feature": TensorArray(np.arange(4096).reshape((4, 32, 32))),
    "label": [1, 2, 3, 4],
})
ds = ray.data.from_pandas(df)

# Convert the dataset to a Torch IterableDataset.
torch_ds = ds.to_torch(
    label_column="label",
    batch_size=2,
    unsqueeze_label_tensor=False,
    unsqueeze_feature_tensors=False,
)

# A feature tensor and a label tensor are yielded per batch.
for X, y in torch_ds:
    # Train model(X, y).
    pass

.. tabbed:: TensorFlow

Convert a ``Dataset`` containing a single tensor feature column to a TensorFlow ``tf.data.Dataset``.

.. code-block:: python

import ray
import numpy as np
import pandas as pd
import tensorflow as tf

from ray.data.extensions import TensorArray

tensor_element_shape = (32, 32)

df = pd.DataFrame({
    "feature": TensorArray(np.arange(4096).reshape((4,) + tensor_element_shape)),
    "label": [1, 2, 3, 4],
})
ds = ray.data.from_pandas(df)

# Convert the dataset to a TensorFlow Dataset.
tf_ds = ds.to_tf(
    label_column="label",
    output_signature=(
        tf.TensorSpec(shape=(None, 1) + tensor_element_shape, dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.float32),
    ),
    batch_size=2,
)

# A feature tensor and a label tensor are yielded per batch.
for X, y in tf_ds:
    # Train model(X, y).
    pass

If your columns have different types **OR** your (tensor) columns have different shapes,
these columns are incompatible and you will not be able to stack the column tensors
into a single tensor. Instead, you will need to group the columns by compatibility in
the ``feature_columns`` argument.

E.g., if columns ``"feature_1"`` and ``"feature_2"`` are incompatible, you should pass
``feature_columns=[["feature_1"], ["feature_2"]]`` to ``to_torch()`` in order to
instruct it to return separate tensors for ``"feature_1"`` and ``"feature_2"``. For
``to_torch()``, when isolating single columns as in the ``"feature_1"`` + ``"feature_2"``
example, you may also want to pass ``unsqueeze_feature_tensors=False`` to avoid adding
a redundant column dimension to each of the single-column tensors.
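The incompatibility described above comes down to array stacking: columns with matching dtypes and per-row shapes can be stacked into one batch tensor, while mismatched ones cannot. A framework-free sketch with plain NumPy (illustrative only; not Ray's exact internal code):

```python
import numpy as np

# Two compatible scalar columns stack cleanly into a single (batch, 2) array.
col_a = np.array([1, 2, 3, 4])
col_b = np.array([5, 6, 7, 8])
stacked = np.stack([col_a, col_b], axis=1)

# A tensor column with (32, 32) elements cannot be stacked with a scalar
# column: the per-row shapes differ, so NumPy refuses to combine them.
tensor_col = np.arange(4096).reshape((4, 32, 32))
try:
    np.stack([tensor_col, col_b], axis=1)
    incompatible = False
except ValueError:
    incompatible = True
```

Grouping such columns separately in ``feature_columns`` sidesteps the failed stack by yielding one tensor per group.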

.. tabbed:: Torch

Convert a ``Dataset`` containing a tensor feature column and a scalar feature column
to a Torch ``IterableDataset``.

.. code-block:: python

import ray
import numpy as np
import pandas as pd
import torch

from ray.data.extensions import TensorArray

df = pd.DataFrame({
    "feature_1": TensorArray(np.arange(4096).reshape((4, 32, 32))),
    "feature_2": [5, 6, 7, 8],
    "label": [1, 2, 3, 4],
})
ds = ray.data.from_pandas(df)

# Convert the dataset to a Torch IterableDataset.
torch_ds = ds.to_torch(
    label_column="label",
    feature_columns=[["feature_1"], ["feature_2"]],
    batch_size=2,
    unsqueeze_label_tensor=False,
    unsqueeze_feature_tensors=False,
)

# Two feature tensors and one label tensor are yielded per batch.
for (feature_1, feature_2), y in torch_ds:
    # Train model((feature_1, feature_2), y).
    pass

.. tabbed:: TensorFlow

Convert a ``Dataset`` containing a tensor feature column and a scalar feature column
to a TensorFlow ``tf.data.Dataset``.

.. code-block:: python

import ray
import numpy as np
import pandas as pd
import tensorflow as tf

from ray.data.extensions import TensorArray

tensor_element_shape = (32, 32)

df = pd.DataFrame({
    "feature_1": TensorArray(np.arange(4096).reshape((4,) + tensor_element_shape)),
    "feature_2": [5, 6, 7, 8],
    "label": [1, 2, 3, 4],
})
ds = ray.data.from_pandas(df)

# Convert the dataset to a TensorFlow Dataset.
tf_ds = ds.to_tf(
    label_column="label",
    feature_columns=[["feature_1"], ["feature_2"]],
    output_signature=(
        (
            tf.TensorSpec(shape=(None, 1) + tensor_element_shape, dtype=tf.float32),
            tf.TensorSpec(shape=(None, 1), dtype=tf.int64),
        ),
        tf.TensorSpec(shape=(None,), dtype=tf.float32),
    ),
    batch_size=2,
)

# Two feature tensors and one label tensor are yielded per batch.
for (feature_1, feature_2), y in tf_ds:
    # Train model((feature_1, feature_2), y).
    pass

End-to-end workflow with our Pandas extension type
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -246,5 +409,3 @@ This feature currently comes with a few known limitations that we are either actively working on addressing:

* All tensors in a tensor column currently must be the same shape. Please let us know if you require heterogeneous tensor shape for your tensor column! Tracking issue is `here <https://github.com/ray-project/ray/issues/18316>`__.
* Automatic casting via specifying an override Arrow schema when reading Parquet is blocked by Arrow supporting custom ExtensionType casting kernels. See `issue <https://issues.apache.org/jira/browse/ARROW-5890>`__. An explicit ``tensor_column_schema`` parameter has been added for :func:`read_parquet() <ray.data.read_api.read_parquet>` as a stopgap solution.
* Ingesting tables with tensor columns into pytorch via ``ds.to_torch()`` is blocked by pytorch supporting tensor creation from objects that implement the `__array__` interface. See `issue <https://github.com/pytorch/pytorch/issues/51156>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18314>`__.
Review comments on doc/source/data/dataset-tensor-support.rst:

Contributor: 👍 doc updates. Could we also add some explicit examples / text in the docs on how to do this?

clarkzinzow (author), May 13, 2022: Hmm, document how to do what? This should now transparently work with tensor columns without any considerations by the user, so there wouldn't be anything special to document at exchange time.

Contributor: It's still not clear to me, as someone only passably familiar with the tensor code, how exactly I would use it. Could we have a runnable example showing it end to end with a tensor dataset? Basically whatever's in the unit test, but in a more friendly example form.

clarkzinzow (author), May 13, 2022: Agreed that an e2e example for tensor data is a good idea, and I'm about to create one as part of the Datasets GA docs work in a separate PR, but the point I'm making here is that the user shouldn't have to do anything differently in the ``.to_torch()`` or ``.to_tf()`` call just because their data contains tensor columns.

The tensor column --> ML framework tensor conversion happens automatically, without any tensor-column-specific args to ``.to_torch()`` or ``.to_tf()``, or special considerations for the upstream tensor column creation beyond what's already documented in the tensor column guide: creating the tensor column in the first place. For ``.to_tf()``, the tensor spec of the columns needs to be given, but that is already a requirement for all columns.

Since this example work is already planned for before GA, could we merge this in order to unblock the user PoC that depends on this?

Contributor: I don't think the example I'm asking for is any longer than this thread already is. Let's set a high standard for documentation.

clarkzinzow (author): I pushed up an example so we can resolve this. I'll have to rewrite/delete this example early next week once other work is merged, which is one of the reasons I didn't want to block the user PoC on this example, so I'd like it if we could be a bit more pragmatic in the future.

Contributor: Thanks. I don't see any harm in blocking this PR on the e2e example, in fact. We should be treating docs like unit tests, not as a separate "optional" component. There shouldn't really be a separate e2e example effort, for example. I left a separate comment on making the example here more fleshed out (maybe this is copying work from your other effort).

* Ingesting tables with tensor columns into TensorFlow via ``ds.to_tf()`` is blocked by a Pandas fix for properly interpreting extension arrays in ``DataFrame.values`` being released. See `PR <https://github.com/pandas-dev/pandas/pull/43160>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18315>`__.
31 changes: 25 additions & 6 deletions python/ray/data/dataset.py
@@ -2380,6 +2380,27 @@ def to_tf(
if isinstance(output_signature, list):
output_signature = tuple(output_signature)

def get_df_values(df: "pandas.DataFrame") -> np.ndarray:
    # TODO(Clark): Support unsqueezing column dimension API, similar to
    # to_torch().
    try:
        values = df.values
    except ValueError as e:
        import pandas as pd

        # Pandas DataFrame.values doesn't support extension arrays in all
        # supported Pandas versions, so we check to see if this DataFrame
        # contains any extension arrays and do a manual conversion if so.
        # See https://github.com/pandas-dev/pandas/pull/43160.
        if any(
            isinstance(dtype, pd.api.extensions.ExtensionDtype)
            for dtype in df.dtypes
        ):
            values = np.stack([col.to_numpy() for _, col in df.items()], axis=1)
Review comments on python/ray/data/dataset.py:

Contributor: Hmm, do we want to convert all columns to NumPy if any column is an ExtensionDtype? Or should we only do ``col.to_numpy()`` if the dtype of that particular column is an ExtensionDtype?

clarkzinzow (author): It doesn't matter; ``df.values`` converts all columns to NumPy ndarrays anyway, so it would be exactly the same.

        else:
            raise e from None
    return values
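The fallback in ``get_df_values`` above can be exercised outside Ray. A standalone sketch with plain pandas/NumPy, showing that the manual per-column conversion matches ``DataFrame.values`` for ordinary dtypes and how extension-typed columns are detected:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

# The manual fallback: convert each column to an ndarray, then stack along a
# new column axis. For ordinary dtypes this matches DataFrame.values exactly.
manual = np.stack([col.to_numpy() for _, col in df.items()], axis=1)

# Detect extension-typed columns the same way the fallback does; plain int64
# and float64 columns are not extension dtypes.
has_ext = any(
    isinstance(dtype, pd.api.extensions.ExtensionDtype) for dtype in df.dtypes
)
```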

def make_generator():
for batch in self.iter_batches(
prefetch_blocks=prefetch_blocks,
@@ -2392,13 +2413,13 @@ def make_generator():

features = None
if feature_columns is None:
features = batch.values
features = get_df_values(batch)
elif isinstance(feature_columns, list):
if all(isinstance(column, str) for column in feature_columns):
features = batch[feature_columns].values
features = get_df_values(batch[feature_columns])
elif all(isinstance(columns, list) for columns in feature_columns):
features = tuple(
batch[columns].values for columns in feature_columns
get_df_values(batch[columns]) for columns in feature_columns
)
else:
raise ValueError(
@@ -2407,7 +2428,7 @@
)
elif isinstance(feature_columns, dict):
features = {
key: batch[columns].values
key: get_df_values(batch[columns])
for key, columns in feature_columns.items()
}
else:
@@ -2416,8 +2437,6 @@
f"but got a `{type(feature_columns).__name__}` instead."
)

# TODO(Clark): Support batches containing our extension array
# TensorArray.
if label_column:
yield features, targets
else: