diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 4442ba449eb5..ab6774a87476 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -3204,7 +3204,9 @@ def convert_batch_to_tensors( if isinstance(columns, str): return convert_ndarray_to_tf_tensor(batch[columns], type_spec=type_spec) return { - convert_ndarray_to_tf_tensor(batch[column], type_spec=type_spec[column]) + column: convert_ndarray_to_tf_tensor( + batch[column], type_spec=type_spec[column] + ) for column in columns } diff --git a/python/ray/data/tests/test_dataset_tf.py b/python/ray/data/tests/test_dataset_tf.py index 6ff6a37015c1..235fdee19354 100644 --- a/python/ray/data/tests/test_dataset_tf.py +++ b/python/ray/data/tests/test_dataset_tf.py @@ -46,6 +46,27 @@ def test_element_spec_type_with_multiple_columns(self): for value in feature_output_signature.values() ) + df = pd.DataFrame( + {"feature1": [0, 1, 2], "feature2": [3, 4, 5], "label": [0, 1, 1]} + ) + ds = ray.data.from_pandas(df) + dataset = ds.to_tf( + feature_columns=["feature1", "feature2"], + label_columns="label", + batch_size=3, + ) + feature_output_signature, _ = dataset.element_spec + assert isinstance(feature_output_signature, dict) + assert feature_output_signature.keys() == {"feature1", "feature2"} + assert all( + isinstance(value, tf.TypeSpec) + for value in feature_output_signature.values() + ) + features, labels = next(iter(dataset)) + assert (labels.numpy() == df["label"].values).all() + assert (features["feature1"].numpy() == df["feature1"].values).all() + assert (features["feature2"].numpy() == df["feature2"].values).all() + def test_element_spec_name(self): ds = ray.data.from_items([{"spam": 0, "ham": 0}])