From 52504d51f114439b15bcfce69aeb2649d50cbaa9 Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Thu, 12 May 2022 23:34:36 +0000
Subject: [PATCH 1/3] Support tensor columns in to_tf and to_torch.

---
 doc/source/data/dataset-tensor-support.rst |   2 -
 python/ray/data/dataset.py                 |  31 ++++-
 python/ray/data/tests/test_dataset.py      | 141 ++++++++++++++++-----
 python/ray/ml/utils/torch_utils.py         |  12 +-
 4 files changed, 147 insertions(+), 39 deletions(-)

diff --git a/doc/source/data/dataset-tensor-support.rst b/doc/source/data/dataset-tensor-support.rst
index 143f7dec8d30..df1b0a8fd6c6 100644
--- a/doc/source/data/dataset-tensor-support.rst
+++ b/doc/source/data/dataset-tensor-support.rst
@@ -246,5 +246,3 @@ This feature currently comes with a few known limitations that we are either act
 
 * All tensors in a tensor column currently must be the same shape. Please let us know if you require heterogeneous tensor shape for your tensor column! Tracking issue is `here `__.
 * Automatic casting via specifying an override Arrow schema when reading Parquet is blocked by Arrow supporting custom ExtensionType casting kernels. See `issue `__. An explicit ``tensor_column_schema`` parameter has been added for :func:`read_parquet() <ray.data.read_parquet>` as a stopgap solution.
-* Ingesting tables with tensor columns into pytorch via ``ds.to_torch()`` is blocked by pytorch supporting tensor creation from objects that implement the `__array__` interface. See `issue `__. Workarounds are being `investigated `__.
-* Ingesting tables with tensor columns into TensorFlow via ``ds.to_tf()`` is blocked by a Pandas fix for properly interpreting extension arrays in ``DataFrame.values`` being released. See `PR `__. Workarounds are being `investigated `__.
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 5f47407848e2..5560d71c5138 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2380,6 +2380,27 @@ def to_tf(
         if isinstance(output_signature, list):
             output_signature = tuple(output_signature)
 
+        def get_df_values(df: "pandas.DataFrame") -> np.ndarray:
+            # TODO(Clark): Support unsqueezing column dimension API, simialr to
+            # to_torch().
+            try:
+                values = df.values
+            except ValueError as e:
+                import pandas as pd
+
+                # Pandas DataFrame.values doesn't support extension arrays in all
+                # supported Pandas versions, so we check to see if this DataFrame
+                # contains any extension arrays and do a manual conversion if so.
+                # See https://github.com/pandas-dev/pandas/pull/43160.
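+                # Stacking the per-column ndarrays along axis 1 below mirrors the
+                # (num_rows, num_columns, ...) layout of DataFrame.values, with any
+                # tensor element dimensions trailing the column dimension.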
+                if any(
+                    isinstance(dtype, pd.api.extensions.ExtensionDtype)
+                    for dtype in df.dtypes
+                ):
+                    values = np.stack([col.to_numpy() for _, col in df.items()], axis=1)
+                else:
+                    raise e from None
+            return values
+
         def make_generator():
             for batch in self.iter_batches(
                 prefetch_blocks=prefetch_blocks,
@@ -2392,13 +2413,13 @@ def make_generator():
 
                 features = None
                 if feature_columns is None:
-                    features = batch.values
+                    features = get_df_values(batch)
                 elif isinstance(feature_columns, list):
                     if all(isinstance(column, str) for column in feature_columns):
-                        features = batch[feature_columns].values
+                        features = get_df_values(batch[feature_columns])
                     elif all(isinstance(columns, list) for columns in feature_columns):
                         features = tuple(
-                            batch[columns].values for columns in feature_columns
+                            get_df_values(batch[columns]) for columns in feature_columns
                         )
                     else:
                         raise ValueError(
@@ -2407,7 +2428,7 @@ def make_generator():
                         )
                 elif isinstance(feature_columns, dict):
                     features = {
-                        key: batch[columns].values
+                        key: get_df_values(batch[columns])
                         for key, columns in feature_columns.items()
                     }
                 else:
@@ -2416,8 +2437,6 @@ def make_generator():
                         f"but got a `{type(feature_columns).__name__}` instead."
                     )
 
-                # TODO(Clark): Support batches containing our extension array
-                # TensorArray.
                 if label_column:
                     yield features, targets
                 else:
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index a942becdf92f..56055c4f9997 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -950,49 +950,83 @@ def test_tensors_in_tables_parquet_bytes_with_schema(
             np.testing.assert_equal(v, e)
 
 
-@pytest.mark.skip(
-    reason=(
-        "Waiting for pytorch to support tensor creation from objects that "
-        "implement the __array__ interface. See "
-        "https://github.com/pytorch/pytorch/issues/51156"
-    )
-)
 @pytest.mark.parametrize("pipelined", [False, True])
 def test_tensors_in_tables_to_torch(ray_start_regular_shared, pipelined):
-    import torch
-
     outer_dim = 3
     inner_shape = (2, 2, 2)
     shape = (outer_dim,) + inner_shape
     num_items = np.prod(np.array(shape))
     arr = np.arange(num_items).reshape(shape)
     df1 = pd.DataFrame(
-        {"one": [1, 2, 3], "two": TensorArray(arr), "label": [1.0, 2.0, 3.0]}
+        {"one": TensorArray(arr), "two": TensorArray(arr + 1), "label": [1.0, 2.0, 3.0]}
     )
     arr2 = np.arange(num_items, 2 * num_items).reshape(shape)
     df2 = pd.DataFrame(
-        {"one": [4, 5, 6], "two": TensorArray(arr2), "label": [4.0, 5.0, 6.0]}
+        {
+            "one": TensorArray(arr2),
+            "two": TensorArray(arr2 + 1),
+            "label": [4.0, 5.0, 6.0],
+        }
     )
     df = pd.concat([df1, df2])
     ds = ray.data.from_pandas([df1, df2])
     ds = maybe_pipeline(ds, pipelined)
     torchd = ds.to_torch(label_column="label", batch_size=2)
 
-    num_epochs = 2
+    num_epochs = 1 if pipelined else 2
     for _ in range(num_epochs):
         iterations = []
         for batch in iter(torchd):
-            iterations.append(torch.cat((*batch[0], batch[1]), axis=1).numpy())
+            iterations.append(batch[0].numpy())
         combined_iterations = np.concatenate(iterations)
-        assert np.array_equal(np.sort(df.values), np.sort(combined_iterations))
+        values = np.stack([df["one"].to_numpy(), df["two"].to_numpy()], axis=1)
+        np.testing.assert_array_equal(np.sort(values), np.sort(combined_iterations))
 
 
-@pytest.mark.skip(
-    reason=(
-        "Waiting for Pandas DataFrame.values for extension arrays fix to be "
-        "released. See https://github.com/pandas-dev/pandas/pull/43160"
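+# The "mix" tests below cover DataFrames that combine tensor columns with
+# scalar columns; such columns can't be stacked into a single tensor, so they
+# are requested as separate per-column tensors via nested feature_columns.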
+@pytest.mark.parametrize("pipelined", [False, True])
+def test_tensors_in_tables_to_torch_mix(ray_start_regular_shared, pipelined):
+    outer_dim = 3
+    inner_shape = (2, 2, 2)
+    shape = (outer_dim,) + inner_shape
+    num_items = np.prod(np.array(shape))
+    arr = np.arange(num_items).reshape(shape)
+    df1 = pd.DataFrame(
+        {
+            "one": TensorArray(arr),
+            "two": [1, 2, 3],
+            "label": [1.0, 2.0, 3.0],
+        }
     )
-)
+    arr2 = np.arange(num_items, 2 * num_items).reshape(shape)
+    df2 = pd.DataFrame(
+        {
+            "one": TensorArray(arr2),
+            "two": [4, 5, 6],
+            "label": [4.0, 5.0, 6.0],
+        }
+    )
+    df = pd.concat([df1, df2])
+    ds = ray.data.from_pandas([df1, df2])
+    ds = maybe_pipeline(ds, pipelined)
+    torchd = ds.to_torch(
+        label_column="label",
+        feature_columns=[["one"], ["two"]],
+        batch_size=2,
+        unsqueeze_feature_tensors=False,
+    )
+
+    num_epochs = 1 if pipelined else 2
+    for _ in range(num_epochs):
+        col1, col2 = [], []
+        for batch in iter(torchd):
+            features = batch[0]
+            col1.append(features[0].numpy())
+            col2.append(features[1].numpy())
+        col1, col2 = np.concatenate(col1), np.concatenate(col2)
+        np.testing.assert_array_equal(col1, np.sort(df["one"].to_numpy()))
+        np.testing.assert_array_equal(col2, np.sort(df["two"].to_numpy()))
+
+
 @pytest.mark.parametrize("pipelined", [False, True])
 def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
     import tensorflow as tf
@@ -1002,21 +1036,19 @@ def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
     shape = (outer_dim,) + inner_shape
     num_items = np.prod(np.array(shape))
     arr = np.arange(num_items).reshape(shape).astype(np.float)
-    # TODO(Clark): Ensure that heterogeneous columns is properly supported
-    # (tf.RaggedTensorSpec)
     df1 = pd.DataFrame(
         {
             "one": TensorArray(arr),
-            "two": TensorArray(arr),
-            "label": TensorArray(arr),
+            "two": TensorArray(arr + 1),
+            "label": [1, 2, 3],
         }
     )
     arr2 = np.arange(num_items, 2 * num_items).reshape(shape).astype(np.float)
     df2 = pd.DataFrame(
         {
             "one": TensorArray(arr2),
-            "two": TensorArray(arr2),
-            "label": TensorArray(arr2),
+            "two": TensorArray(arr2 + 1),
+            "label": [4, 5, 6],
         }
     )
     df = pd.concat([df1, df2])
@@ -1026,15 +1058,66 @@ def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
         label_column="label",
         output_signature=(
             tf.TensorSpec(shape=(None, 2, 2, 2, 2), dtype=tf.float32),
-            tf.TensorSpec(shape=(None, 1, 2, 2, 2), dtype=tf.float32),
+            tf.TensorSpec(shape=(None,), dtype=tf.float32),
         ),
+        batch_size=2,
     )
     iterations = []
     for batch in tfd.as_numpy_iterator():
-        iterations.append(np.concatenate((batch[0], batch[1]), axis=1))
+        iterations.append(batch[0])
     combined_iterations = np.concatenate(iterations)
-    arr = np.array([[np.asarray(v) for v in values] for values in df.to_numpy()])
-    np.testing.assert_array_equal(arr, combined_iterations)
+    values = np.stack([df["one"].to_numpy(), df["two"].to_numpy()], axis=1)
+    np.testing.assert_array_equal(values, combined_iterations)
+
+
+@pytest.mark.parametrize("pipelined", [False, True])
+def test_tensors_in_tables_to_tf_mix(ray_start_regular_shared, pipelined):
+    import tensorflow as tf
+
+    outer_dim = 3
+    inner_shape = (2, 2, 2)
+    shape = (outer_dim,) + inner_shape
+    num_items = np.prod(np.array(shape))
+    arr = np.arange(num_items).reshape(shape).astype(np.float)
+    df1 = pd.DataFrame(
+        {
+            "one": TensorArray(arr),
+            "two": [1, 2, 3],
+            "label": [1.0, 2.0, 3.0],
+        }
+    )
+    arr2 = np.arange(num_items, 2 * num_items).reshape(shape).astype(np.float)
+    df2 = pd.DataFrame(
+        {
+            "one": TensorArray(arr2),
+            "two": [4, 5, 6],
+            "label": [4.0, 5.0, 6.0],
+        }
+    )
+    df = pd.concat([df1, df2])
+    ds = ray.data.from_pandas([df1, df2])
+    ds = maybe_pipeline(ds, pipelined)
+    tfd = ds.to_tf(
+        label_column="label",
+        feature_columns=[["one"], ["two"]],
+        output_signature=(
+            (
+                tf.TensorSpec(shape=(None, 1, 2, 2, 2), dtype=tf.float32),
+                tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
+            ),
+            tf.TensorSpec(shape=(None,), dtype=tf.float32),
+        ),
+        batch_size=2,
+    )
+    col1, col2 = [], []
+    for batch in tfd.as_numpy_iterator():
+        features = batch[0]
+        col1.append(features[0])
+        col2.append(features[1])
+    col1 = np.squeeze(np.concatenate(col1), axis=1)
+    col2 = np.squeeze(np.concatenate(col2), axis=1)
+    np.testing.assert_array_equal(col1, np.sort(df["one"].to_numpy()))
+    np.testing.assert_array_equal(col2, np.sort(df["two"].to_numpy()))
 
 
 def test_empty_shuffle(ray_start_regular_shared):
diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py
index 963fca711109..5e1f88687429 100644
--- a/python/ray/ml/utils/torch_utils.py
+++ b/python/ray/ml/utils/torch_utils.py
@@ -54,6 +54,12 @@ def convert_pandas_to_torch_tensor(
     def tensorize(vals, dtype):
         """This recursive function allows to convert pyarrow List dtypes
         to multi-dimensional tensors."""
+        if isinstance(vals, pd.api.extensions.ExtensionArray):
+            # torch.as_tensor() does not yet support the __array__ protocol, so we need
+            # to convert extension arrays to ndarrays manually before converting to a
+            # Torch tensor.
+            # See https://github.com/pytorch/pytorch/issues/51156.
+            vals = vals.to_numpy()
         try:
             return torch.as_tensor(vals, dtype=dtype)
         except TypeError:
@@ -79,8 +85,10 @@ def get_tensor_for_columns(columns, dtype):
             feature_tensors.append(t)
 
         if len(feature_tensors) > 1:
-            return torch.cat(feature_tensors, dim=1)
-        return feature_tensors[0]
+            feature_tensor = torch.cat(feature_tensors, dim=1)
+        else:
+            feature_tensor = feature_tensors[0]
+        return feature_tensor
 
     if multi_input:
         if type(column_dtypes) not in [list, tuple]:

From b37e88764dd68d12e6f0ad49a3034b94d637d92d Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Fri, 13 May 2022 00:34:36 +0000
Subject: [PATCH 2/3] PR feedback: comment typo, add back labels check.

---
 python/ray/data/dataset.py            |  2 +-
 python/ray/data/tests/test_dataset.py | 45 +++++++++++++++++----------
 2 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 5560d71c5138..614a5af3c6dc 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2381,7 +2381,7 @@ def to_tf(
         output_signature = tuple(output_signature)
 
         def get_df_values(df: "pandas.DataFrame") -> np.ndarray:
-            # TODO(Clark): Support unsqueezing column dimension API, simialr to
+            # TODO(Clark): Support unsqueezing column dimension API, similar to
             # to_torch().
             try:
                 values = df.values
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 56055c4f9997..9c31c769b28f 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -971,16 +971,20 @@ def test_tensors_in_tables_to_torch(ray_start_regular_shared, pipelined):
     df = pd.concat([df1, df2])
     ds = ray.data.from_pandas([df1, df2])
     ds = maybe_pipeline(ds, pipelined)
-    torchd = ds.to_torch(label_column="label", batch_size=2)
+    torchd = ds.to_torch(
+        label_column="label", batch_size=2, unsqueeze_label_tensor=False
+    )
 
     num_epochs = 1 if pipelined else 2
     for _ in range(num_epochs):
-        iterations = []
+        features, labels = [], []
         for batch in iter(torchd):
-            iterations.append(batch[0].numpy())
-        combined_iterations = np.concatenate(iterations)
+            features.append(batch[0].numpy())
+            labels.append(batch[1].numpy())
+        features, labels = np.concatenate(features), np.concatenate(labels)
         values = np.stack([df["one"].to_numpy(), df["two"].to_numpy()], axis=1)
-        np.testing.assert_array_equal(np.sort(values), np.sort(combined_iterations))
+        np.testing.assert_array_equal(values, features)
+        np.testing.assert_array_equal(df["label"].to_numpy(), labels)
 
 
 @pytest.mark.parametrize("pipelined", [False, True])
@@ -1012,19 +1016,22 @@ def test_tensors_in_tables_to_torch_mix(ray_start_regular_shared, pipelined):
         label_column="label",
         feature_columns=[["one"], ["two"]],
         batch_size=2,
+        unsqueeze_label_tensor=False,
         unsqueeze_feature_tensors=False,
     )
 
     num_epochs = 1 if pipelined else 2
     for _ in range(num_epochs):
-        col1, col2 = [], []
+        col1, col2, labels = [], [], []
         for batch in iter(torchd):
-            features = batch[0]
-            col1.append(features[0].numpy())
-            col2.append(features[1].numpy())
+            col1.append(batch[0][0].numpy())
+            col2.append(batch[0][1].numpy())
+            labels.append(batch[1].numpy())
         col1, col2 = np.concatenate(col1), np.concatenate(col2)
+        labels = np.concatenate(labels)
         np.testing.assert_array_equal(col1, np.sort(df["one"].to_numpy()))
         np.testing.assert_array_equal(col2, np.sort(df["two"].to_numpy()))
+        np.testing.assert_array_equal(labels, np.sort(df["label"].to_numpy()))
 
 
 @pytest.mark.parametrize("pipelined", [False, True])
@@ -1062,12 +1069,14 @@ def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
         ),
         batch_size=2,
     )
-    iterations = []
+    features, labels = [], []
     for batch in tfd.as_numpy_iterator():
-        iterations.append(batch[0])
-    combined_iterations = np.concatenate(iterations)
+        features.append(batch[0])
+        labels.append(batch[1])
+    features, labels = np.concatenate(features), np.concatenate(labels)
     values = np.stack([df["one"].to_numpy(), df["two"].to_numpy()], axis=1)
-    np.testing.assert_array_equal(values, combined_iterations)
+    np.testing.assert_array_equal(values, features)
+    np.testing.assert_array_equal(df["label"].to_numpy(), labels)
 
 
 @pytest.mark.parametrize("pipelined", [False, True])
@@ -1109,15 +1118,17 @@ def test_tensors_in_tables_to_tf_mix(ray_start_regular_shared, pipelined):
         ),
         batch_size=2,
     )
-    col1, col2 = [], []
+    col1, col2, labels = [], [], []
     for batch in tfd.as_numpy_iterator():
-        features = batch[0]
-        col1.append(features[0])
-        col2.append(features[1])
+        col1.append(batch[0][0])
+        col2.append(batch[0][1])
+        labels.append(batch[1])
     col1 = np.squeeze(np.concatenate(col1), axis=1)
     col2 = np.squeeze(np.concatenate(col2), axis=1)
+    labels = np.concatenate(labels)
     np.testing.assert_array_equal(col1, np.sort(df["one"].to_numpy()))
     np.testing.assert_array_equal(col2, np.sort(df["two"].to_numpy()))
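+    # Labels round-trip through to_tf() as well; the sorted comparison mirrors
+    # the feature-column checks above.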
np.testing.assert_array_equal(labels, np.sort(df["label"].to_numpy())) def test_empty_shuffle(ray_start_regular_shared): From baa360800422e2dd6be8b3a18de14d6f412f8f0f Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Fri, 13 May 2022 21:55:51 +0000 Subject: [PATCH 3/3] Add to_torch()/to_tf() example to docs. --- doc/source/data/dataset-tensor-support.rst | 165 ++++++++++++++++++++- 1 file changed, 164 insertions(+), 1 deletion(-) diff --git a/doc/source/data/dataset-tensor-support.rst b/doc/source/data/dataset-tensor-support.rst index df1b0a8fd6c6..51be49a0047e 100644 --- a/doc/source/data/dataset-tensor-support.rst +++ b/doc/source/data/dataset-tensor-support.rst @@ -107,7 +107,9 @@ If your serialized tensors don't fit the above constraints (e.g. they're stored # -> one: int64 # two: extension> -Please note that the ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions. +.. note:: + + The ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions. Working with tensor column datasets ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -143,6 +145,167 @@ This dataset can then be written to Parquet files. The tensor column schema will # -> one: int64 # two: extension> +Converting to a Torch/TensorFlow Dataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This dataset can also be converted to a Torch or TensorFlow dataset via the standard +:meth:`ds.to_torch() ` and +:meth:`ds.to_tf() ` APIs for ingestion into those respective ML +training frameworks. The tensor column will be automatically converted to a +Torch/TensorFlow tensor without incurring any copies. + +.. note:: + + When converting to a TensorFlow Dataset, you will need to give the full tensor spec + for the tensor columns, including the shape of each underlying tensor element in said + column. + + +.. tabbed:: Torch + + Convert a ``Dataset`` containing a single tensor feature column to a Torch ``IterableDataset``. + + .. code-block:: python + + import ray + import numpy as np + import pandas as pd + import torch + + df = pd.DataFrame({ + "feature": TensorArray(np.arange(4096).reshape((4, 32, 32))), + "label": [1, 2, 3, 4], + }) + ds = ray.data.from_pandas(df) + + # Convert the dataset to a Torch IterableDataset. + torch_ds = ds.to_torch( + label_column="label", + batch_size=2, + unsqueeze_label_tensor=False, + unsqueeze_feature_tensors=False, + ) + + # A feature tensor and label tensor is yielded per batch. + for X, y in torch_ds: + # Train model(X, y) + +.. tabbed:: TensorFlow + + Convert a ``Dataset`` containing a single tensor feature column to a TensorFlow ``tf.data.Dataset``. + + .. code-block:: python + + import ray + import numpy as np + import pandas as pd + import tensorflow as tf + + tensor_element_shape = (32, 32) + + df = pd.DataFrame({ + "feature": TensorArray(np.arange(4096).reshape((4,) + tensor_element_shape)), + "label": [1, 2, 3, 4], + }) + ds = ray.data.from_pandas(df) + + # Convert the dataset to a TensorFlow Dataset. + tf_ds = ds.to_tf( + label_column="label", + output_signature=( + tf.TensorSpec(shape=(None, 1) + tensor_element_shape, dtype=tf.float32), + tf.TensorSpec(shape=(None,), dtype=tf.float32), + ), + batch_size=2, + ) + + # A feature tensor and label tensor is yielded per batch. 
+        for X, y in tf_ds:
+            # Train model(X, y) here.
+            pass
+
+If your columns have different types **OR** your (tensor) columns have different shapes,
+these columns are incompatible and you will not be able to stack the column tensors
+into a single tensor. Instead, you will need to group the columns by compatibility in
+the ``feature_columns`` argument.
+
+E.g., if columns ``"feature_1"`` and ``"feature_2"`` are incompatible, you should give
+``to_torch()`` a ``feature_columns=[["feature_1"], ["feature_2"]]`` argument in order to
+instruct it to return separate tensors for ``"feature_1"`` and ``"feature_2"``. For
+``to_torch()``, if isolating single columns as in the ``"feature_1"`` + ``"feature_2"``
+example, you may also want to provide ``unsqueeze_feature_tensors=False`` in order to
+remove the redundant column dimension for each of the unit column tensors.
+
+.. tabbed:: Torch
+
+    Convert a ``Dataset`` containing a tensor feature column and a scalar feature column
+    to a Torch ``IterableDataset``.
+
+    .. code-block:: python
+
+        import ray
+        import numpy as np
+        import pandas as pd
+        import torch
+
+        from ray.data.extensions import TensorArray
+
+        df = pd.DataFrame({
+            "feature_1": TensorArray(np.arange(4096).reshape((4, 32, 32))),
+            "feature_2": [5, 6, 7, 8],
+            "label": [1, 2, 3, 4],
+        })
+        ds = ray.data.from_pandas(df)
+
+        # Convert the dataset to a Torch IterableDataset.
+        torch_ds = ds.to_torch(
+            label_column="label",
+            feature_columns=[["feature_1"], ["feature_2"]],
+            batch_size=2,
+            unsqueeze_label_tensor=False,
+            unsqueeze_feature_tensors=False,
+        )
+
+        # Two feature tensors and one label tensor are yielded per batch.
+        for (feature_1, feature_2), y in torch_ds:
+            # Train model((feature_1, feature_2), y) here.
+            pass
+
+.. tabbed:: TensorFlow
+
+    Convert a ``Dataset`` containing a tensor feature column and a scalar feature column
+    to a TensorFlow ``tf.data.Dataset``.
+
+    .. code-block:: python
+
+        import ray
+        import numpy as np
+        import pandas as pd
+        import tensorflow as tf
+
+        from ray.data.extensions import TensorArray
+
+        tensor_element_shape = (32, 32)
+
+        df = pd.DataFrame({
+            "feature_1": TensorArray(np.arange(4096).reshape((4,) + tensor_element_shape)),
+            "feature_2": [5, 6, 7, 8],
+            "label": [1, 2, 3, 4],
+        })
+        ds = ray.data.from_pandas(df)
+
+        # Convert the dataset to a TensorFlow Dataset.
+        tf_ds = ds.to_tf(
+            label_column="label",
+            feature_columns=[["feature_1"], ["feature_2"]],
+            output_signature=(
+                (
+                    tf.TensorSpec(shape=(None, 1) + tensor_element_shape, dtype=tf.float32),
+                    tf.TensorSpec(shape=(None, 1), dtype=tf.int64),
+                ),
+                tf.TensorSpec(shape=(None,), dtype=tf.float32),
+            ),
+            batch_size=2,
+        )
+
+        # Two feature tensors and one label tensor are yielded per batch.
+        for (feature_1, feature_2), y in tf_ds:
+            # Train model((feature_1, feature_2), y) here.
+            pass
+
 End-to-end workflow with our Pandas extension type
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~