Skip to content

Commit

Permalink
[Datasets] Fix to_tf when the input feature_columns is a list. (ray…
Browse files Browse the repository at this point in the history
…-project#31228)

This fixes a bug in dataset.to_tf method. This method should returns a dictionary when input feature_columns is a list, but in the current release 2.2 and master branch it is returning a set, which also fails due to unhashable type.

Signed-off-by: tmynn <[email protected]>
  • Loading branch information
n30111 authored and tamohannes committed Jan 25, 2023
1 parent d6c7a06 commit f063e91
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3204,7 +3204,9 @@ def convert_batch_to_tensors(
if isinstance(columns, str):
return convert_ndarray_to_tf_tensor(batch[columns], type_spec=type_spec)
return {
convert_ndarray_to_tf_tensor(batch[column], type_spec=type_spec[column])
column: convert_ndarray_to_tf_tensor(
batch[column], type_spec=type_spec[column]
)
for column in columns
}

Expand Down
21 changes: 21 additions & 0 deletions python/ray/data/tests/test_dataset_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ def test_element_spec_type_with_multiple_columns(self):
for value in feature_output_signature.values()
)

df = pd.DataFrame(
{"feature1": [0, 1, 2], "feature2": [3, 4, 5], "label": [0, 1, 1]}
)
ds = ray.data.from_pandas(df)
dataset = ds.to_tf(
feature_columns=["feature1", "feature2"],
label_columns="label",
batch_size=3,
)
feature_output_signature, _ = dataset.element_spec
assert isinstance(feature_output_signature, dict)
assert feature_output_signature.keys() == {"feature1", "feature2"}
assert all(
isinstance(value, tf.TypeSpec)
for value in feature_output_signature.values()
)
features, labels = next(iter(dataset))
assert (labels.numpy() == df["label"].values).all()
assert (features["feature1"].numpy() == df["feature1"].values).all()
assert (features["feature2"].numpy() == df["feature2"].values).all()

def test_element_spec_name(self):
ds = ray.data.from_items([{"spam": 0, "ham": 0}])

Expand Down

0 comments on commit f063e91

Please sign in to comment.