Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datasets] Simplify to_tf interface #29028

Merged
merged 16 commits into from
Oct 7, 2022
Merged
135 changes: 40 additions & 95 deletions python/ray/air/_internal/tensorflow_utils.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,12 @@
from typing import Optional, Union, Dict
from typing import Dict, List, Optional, Union, Tuple

import numpy as np
import pandas as pd
import pyarrow
import tensorflow as tf
from pandas.api.types import is_object_dtype

from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed


def convert_pandas_to_tf_tensor(
df: pd.DataFrame, dtype: Optional[tf.dtypes.DType] = None
) -> tf.Tensor:
"""Convert a pandas dataframe to a TensorFlow tensor.

This function works in two steps:
1. Convert each dataframe column to a tensor.
2. Concatenate the resulting tensors along the last axis.

Arguments:
df: The dataframe to convert to a TensorFlow tensor. Columns must be of
a numeric dtype, ``TensorDtype``, or object dtype. If a column has
an object dtype, the column must contain ``ndarray`` objects.
dtype: Optional data type for the returned tensor. If a dtype isn't
provided, the dtype is inferred from ``df``.

Returns:
A tensor constructed from the dataframe.

Examples:
>>> import pandas as pd
>>> from ray.air._internal.tensorflow_utils import convert_pandas_to_tf_tensor
>>>
>>> df = pd.DataFrame({"X1": [1, 2, 3], "X2": [4, 5, 6]})
>>> convert_pandas_to_tf_tensor(df[["X1"]]).shape
TensorShape([3, 1])
>>> convert_pandas_to_tf_tensor(df[["X1", "X2"]]).shape
TensorShape([3, 2])

>>> from ray.data.extensions import TensorArray
>>> import numpy as np
>>>
>>> df = pd.DataFrame({"image": TensorArray(np.zeros((4, 3, 32, 32)))})
>>> convert_pandas_to_tf_tensor(df).shape
TensorShape([4, 3, 32, 32])
"""
if dtype is None:
try:
# We need to cast the tensors to a common type so that we can concatenate
# them. If the columns contain different types (for example, `float32`s
# and `int32`s), then `tf.concat` raises an error.
dtype: np.dtype = np.find_common_type(df.dtypes, [])

# if the columns are `ray.data.extensions.tensor_extension.TensorArray`,
# the dtype will be `object`. In this case, we need to set the dtype to
# none, and use the automatic type casting of `tf.convert_to_tensor`.
if is_object_dtype(dtype):
dtype = None

except TypeError:
# `find_common_type` fails if a series has `TensorDtype`. In this case,
# don't cast any of the series and continue.
pass

def tensorize(series):
try:
return tf.convert_to_tensor(series, dtype=dtype)
except ValueError:
# This exception will be raised if series is of object dtype or otherwise
# cannot be made into a tensor directly. We assume it's a sequence in that
# case. This is more robust than checking for dtype.
tensors = [tensorize(element) for element in series]
try:
return tf.stack(tensors)
except Exception:
# Try to coerce the tensor to a ragged tensor, if possible.
# If this fails, the exception will be propagated up to the caller.
return tf.ragged.stack(tensors)

tensors = []
for column in df.columns:
series = df[column]
try:
tensor = tensorize(series)
except Exception:
raise ValueError(
f"Failed to convert column {column} to a TensorFlow Tensor of dtype "
f"{dtype}. See above exception chain for the exact failure."
)
tensors.append(tensor)

if len(tensors) > 1:
tensors = [tf.expand_dims(tensor, axis=1) for tensor in tensors]

concatenated_tensor = tf.concat(tensors, axis=1)

if concatenated_tensor.shape.ndims == 1:
return tf.expand_dims(concatenated_tensor, axis=1)

return concatenated_tensor


def convert_ndarray_to_tf_tensor(
ndarray: np.ndarray,
dtype: Optional[tf.dtypes.DType] = None,
Expand Down Expand Up @@ -152,3 +59,41 @@ def convert_ndarray_batch_to_tf_tensor_batch(
}

return batch


def get_type_spec(
schema: "pyarrow.lib.Schema",
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
columns: Union[str, List[str]],
) -> Union[tf.TypeSpec, Dict[str, tf.TypeSpec]]:
import pyarrow as pa
from ray.data.extensions import TensorDtype, ArrowTensorType

assert not isinstance(schema, type)

dtypes: Dict[str, Union[np.dtype, pa.DataType]] = {
name: dtype for name, dtype in zip(schema.names, schema.types)
}
bveeramani marked this conversation as resolved.
Show resolved Hide resolved

def get_dtype(dtype: Union[np.dtype, pa.DataType]) -> tf.dtypes.DType:
if isinstance(dtype, pa.DataType):
dtype = dtype.to_pandas_dtype()
if isinstance(dtype, TensorDtype):
dtype = dtype.element_dtype
return tf.dtypes.as_dtype(dtype)

def get_shape(dtype: Union[np.dtype, pa.DataType]) -> Tuple[int, ...]:
if isinstance(dtype, ArrowTensorType):
dtype = dtype.to_pandas_dtype()
if isinstance(dtype, TensorDtype):
return (None,) + dtype.element_shape
return (None,)
bveeramani marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(columns, str):
name, dtype = columns, dtypes[columns]
return tf.TensorSpec(get_shape(dtype), dtype=get_dtype(dtype), name=name)

return {
name: tf.TensorSpec(get_shape(dtype), dtype=get_dtype(dtype), name=name)
for name, dtype in dtypes.items()
if name in columns
}
147 changes: 54 additions & 93 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2787,14 +2787,9 @@ def make_generator():

def to_tf(
self,
feature_columns: Union[str, List[str]],
label_columns: Union[str, List[str]],
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
*,
output_signature: Union[
TensorflowFeatureTypeSpec, Tuple[TensorflowFeatureTypeSpec, "tf.TypeSpec"]
],
label_column: Optional[str] = None,
feature_columns: Optional[
Union[List[str], List[List[str]], Dict[str, List[str]]]
] = None,
prefetch_blocks: int = 0,
batch_size: int = 1,
drop_last: bool = False,
Expand All @@ -2807,49 +2802,18 @@ def to_tf(
``iter_batches`` method. ``prefetch_blocks`` and ``batch_size``
arguments will be passed to that method.

For the features tensor (N is the ``batch_size`` and n1, ..., nk
are the number of features per tensor):

* If ``feature_columns`` is a ``List[str]``, the features will be
a tensor of shape (N, n), with columns corresponding to
``feature_columns``

* If ``feature_columns`` is a ``List[List[str]]``, the features will be
a list of tensors of shape [(N, n1),...,(N, nk)], with columns of each
tensor corresponding to the elements of ``feature_columns``

* If ``feature_columns`` is a ``Dict[str, List[str]]``, the features
will be a dict of key-tensor pairs of shape
{key1: (N, n1),..., keyN: (N, nk)}, with columns of each
tensor corresponding to the value of ``feature_columns`` under the
key.

This is only supported for datasets convertible to Arrow records.

Requires all datasets to have the same columns.

It is recommended to call ``.split()`` on this dataset if
there are to be multiple TensorFlow workers consuming the data.

The elements generated must be compatible with the given
``output_signature`` argument (same as in
``tf.data.Dataset.from_generator``).

Time complexity: O(1)

Args:
output_signature: If ``label_column`` is specified,
a two-element tuple containing a ``FeatureTypeSpec`` and
``tf.TypeSpec`` object corresponding to (features, label). Otherwise, a
single ``TensorflowFeatureTypeSpec`` corresponding to features tensor.
A ``TensorflowFeatureTypeSpec`` is a ``tf.TypeSpec``,
``List["tf.TypeSpec"]``, or ``Dict[str, "tf.TypeSpec"]``.
label_column: The name of the column used as the label
(second element of the output tuple). If not specified, output
will be just one tensor instead of a tuple.
feature_columns: The names of the columns to use as the features. Can be a
list of lists or a dict of string-list pairs for multi-tensor output.
If None, then use all columns except the label columns as the features.
feature_columns: Columns that correspond to inputs. If this is a string,
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
the input data is a tensor. If this is a list, the input data is a
``dict`` that maps column names to their tensor representation.
label_column: Columns that correspond to targets. If this is a string,
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
the target data is a tensor. If this is a list, the target data is a
``dict`` that maps column names to their tensor representation.
prefetch_blocks: The number of blocks to prefetch ahead of the
current block during the scan.
batch_size: Record batch size. Defaults to 1.
Expand All @@ -2868,75 +2832,72 @@ def to_tf(
local_shuffle_seed: The seed to use for the local random shuffle.

Returns:
A tf.data.Dataset.
A ``tf.data.Dataset`` that yields inputs and targets.
"""

# argument exception checking is done in from_generator
from ray.air._internal.tensorflow_utils import get_type_spec
from ray.train.tensorflow import prepare_dataset_shard

try:
import tensorflow as tf
except ImportError:
raise ValueError("tensorflow must be installed!")

from ray.air._internal.tensorflow_utils import convert_pandas_to_tf_tensor
if self._dataset_format() == "simple":
raise NotImplementedError(
"`to_tf` doesn't support simple datasets. Call `map_batches` and "
"convert your data to a tabular format. Alternatively, call the more-"
"flexible `iter_batches` in place of `to_tf`."
)

# `output_signature` can be a tuple but not a list. See
# https://stackoverflow.com/questions/59092423/what-is-a-nested-structure-in-tensorflow.
if isinstance(output_signature, list):
output_signature = tuple(output_signature)
def validate_column(column: str) -> None:
valid_columns = self.schema().names
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
if column not in valid_columns:
raise ValueError(
f"You specified '{column}' in `feature_columns` or "
f"`label_columns`, but there's no column named '{column}' in the "
f"dataset. Valid column names are: {valid_columns}."
)

def make_generator():
for batch in self.iter_batches(
def validate_columns(columns: Union[str, List]) -> None:
if isinstance(columns, list):
for column in columns:
validate_column(column)
else:
validate_column(columns)

validate_columns(feature_columns)
validate_columns(label_columns)

def get_columns_from_batch(
batch: Dict[str, tf.Tensor], *, columns: Union[str, List[str]]
) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
assert isinstance(columns, (str, list))
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(columns, str):
return batch[columns]
return {column: batch[column] for column in columns}

def generator():
for batch in self.iter_tf_batches(
prefetch_blocks=prefetch_blocks,
batch_size=batch_size,
batch_format="pandas",
drop_last=drop_last,
local_shuffle_buffer_size=local_shuffle_buffer_size,
local_shuffle_seed=local_shuffle_seed,
):
if label_column:
targets = convert_pandas_to_tf_tensor(batch[[label_column]])
if targets.ndim == 2 and targets.shape[1] == 1:
targets = tf.squeeze(targets, axis=1)
batch.pop(label_column)
assert isinstance(batch, dict)
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
features = get_columns_from_batch(batch, columns=feature_columns)
labels = get_columns_from_batch(batch, columns=label_columns)
yield features, labels

features = None
if feature_columns is None:
features = convert_pandas_to_tf_tensor(batch)
elif isinstance(feature_columns, list):
if all(isinstance(column, str) for column in feature_columns):
features = convert_pandas_to_tf_tensor(batch[feature_columns])
elif all(isinstance(columns, list) for columns in feature_columns):
features = tuple(
convert_pandas_to_tf_tensor(batch[columns])
for columns in feature_columns
)
else:
raise ValueError(
"Expected `feature_columns` to be a list of strings or a "
"list of lists."
)
elif isinstance(feature_columns, dict):
features = {
key: convert_pandas_to_tf_tensor(batch[columns])
for key, columns in feature_columns.items()
}
else:
raise ValueError(
"Expected `feature_columns` to be a list or a dictionary, "
f"but got a `{type(feature_columns).__name__}` instead."
)

if label_column:
yield features, targets
else:
yield features
feature_type_spec = get_type_spec(self.schema(), columns=feature_columns)
label_type_spec = get_type_spec(self.schema(), columns=label_columns)
bveeramani marked this conversation as resolved.
Show resolved Hide resolved
output_signature = (feature_type_spec, label_type_spec)

dataset = tf.data.Dataset.from_generator(
make_generator, output_signature=output_signature
generator, output_signature=output_signature
)

return dataset
return prepare_dataset_shard(dataset)
bveeramani marked this conversation as resolved.
Show resolved Hide resolved

def to_dask(
self,
Expand Down
Loading