Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Extend API to enable passing sample weights via ray.dataset.to_tf (#45701) #46784

Merged
merged 14 commits into from
Aug 5, 2024

Conversation

jeffreyjeffreywang
Copy link
Contributor

@jeffreyjeffreywang jeffreyjeffreywang commented Jul 25, 2024

Why are these changes needed?

tensorflow.keras.model.fit supports callers to pass in sample weights along with input dataset (https://www.tensorflow.org/versions/r2.8/api_docs/python/tf/keras/Model#fit). However, ray.data.Dataset.to_tf() does not support this feature. Currently, callers can only generate data instances comprised of features and labels.

This PR extends ray.data.Dataset.to_tf() to allow callers to specify additional metadata associated with data samples. Specifically, additional_columns is introduced as a forced kwargs to ray.data.Dataset.to_tf() and ray.DataIterator.to_tf(). When additional_columns is not specified, there is no change to the API behavior. By contrast, while additional_columns is provided, the APIs will yield additional_metadata along with features and labels. To leverage the functionality for passing sample weights along with sample features and labels in tensorflow, one can create tf.data.Dataset from ray.data.Dataset.to_tf() and specifying additional_columns="weight" as weight is the column for storing sample weights in ray.data.Dataset.

We do not explicitly name the new argument sample_weight or weight and do not limit its type to be str as there may be other metadata associated with each sample that we want to yield while iterating through the dataset.

We follow the heuristic of inferring feature_type_spec and label_type_spec when determining additional_type_spec. Existing tests are extended to validate when the APIs are invoked with and without the additional_columns argument to ensure existing callers won't fail to invoke ray.data.Dataset.to_tf() and ray.DataIterator.to_tf() as this new change rolls out.

Related issue number

Resolves #45701

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Copy link
Contributor

@scottjlee scottjlee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for updating all the tests!

can we also add an additional unit test, where we use Dataset.to_tf() with sample weights passed, then we call keras.Model.fit() with the sample weights? so that we can verify that the original intended behavior

python/ray/data/iterator.py Outdated Show resolved Hide resolved
python/ray/data/iterator.py Outdated Show resolved Hide resolved
python/ray/data/iterator.py Show resolved Hide resolved
python/ray/data/iterator.py Outdated Show resolved Hide resolved
jeffreyjeffreywang added 2 commits July 26, 2024 23:12
Signed-off-by: jeffreyjeffreywang <[email protected]>
@jeffreyjeffreywang
Copy link
Contributor Author

@scottjlee Would appreciate your quick review when you have time, no rush 😄

Copy link
Contributor

@scottjlee scottjlee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the contribution!

python/ray/data/iterator.py Show resolved Hide resolved
python/ray/data/iterator.py Show resolved Hide resolved
python/ray/data/iterator.py Outdated Show resolved Hide resolved
@bveeramani bveeramani removed their assignment Jul 31, 2024
Copy link
Contributor

@omatthew98 omatthew98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lgtm, thanks for the contribution!

@scottjlee
Copy link
Contributor

@jeffreyjeffreywang let me know when the docs example is added, and i can take a final look before we get this merged

@jeffreyjeffreywang
Copy link
Contributor Author

Hey @scottjlee, definitely, I plan to get it done today.

jeffreyjeffreywang added 3 commits August 1, 2024 23:38
Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
@jeffreyjeffreywang
Copy link
Contributor Author

Hey @scottjlee, I've added some documentation to explain the spec arguments and the new argument I introduced. Doc examples are validated with pytest --dodctest-modules and built locally with make develop in the doc directory. Here are the sections/examples I added. Please let me know if these look good to you 😄

ray.data.Dataset.to_tf

Screenshot 2024-08-01 at 6 00 37 PM

ray.data.DataIterator.to_tf

Screenshot 2024-08-01 at 6 01 59 PM

@scottjlee
Copy link
Contributor

Thanks, the docs changes look good to me. I am merging in latest master to resolve the build issue (unrelated to this PR)

@scottjlee scottjlee added the go add ONLY when ready to merge, run all tests label Aug 3, 2024
@jeffreyjeffreywang
Copy link
Contributor Author

Thanks Scott! Looking forward to contributing more 😃

@bveeramani bveeramani merged commit 06fb5fc into ray-project:master Aug 5, 2024
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
go add ONLY when ready to merge, run all tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Data] Add API support for passing sample weights via ray.dataset.to_tf
4 participants