[Data] Add API support for passing sample weights via ray.dataset.to_tf #45701
Comments
@japneet-anyscale could you give more details / sample usage on how the new API would be used? if i understand correctly,
After internal discussion, got additional context. Current workaround is to create a new generator which yields the Dataset, and insert the sample_weights:
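The original snippet for this workaround did not survive in the thread, so the following is only a minimal sketch of the idea, not the code that was posted. It assumes batches arrive as plain dicts mapping column names to values; the helper name `with_sample_weights` and the column names `x`, `y`, and `weight` are hypothetical. The generator re-yields each batch as a `(features, labels, sample_weights)` tuple, which is the three-element form that Keras `Model.fit` accepts for its input data.

```python
from typing import Dict, Iterable, Iterator, List, Tuple

# Hypothetical batch layout: column name -> list of per-sample values.
Batch = Dict[str, List[float]]


def with_sample_weights(
    batches: Iterable[Batch],
    feature_column: str,
    label_column: str,
    weight_column: str,
) -> Iterator[Tuple[List[float], List[float], List[float]]]:
    """Re-yield each batch as (features, labels, sample_weights)."""
    for batch in batches:
        yield batch[feature_column], batch[label_column], batch[weight_column]


# Stand-in for the batches that iterating over the Ray Dataset would produce.
batches = [
    {"x": [1.0, 2.0], "y": [0.0, 1.0], "weight": [0.5, 2.0]},
    {"x": [3.0], "y": [1.0], "weight": [1.0]},
]
triples = list(with_sample_weights(batches, "x", "y", "weight"))
```

In a real pipeline the list of dicts would be replaced by the dataset's batch iterator, and the generator would be wrapped in a `tf.data.Dataset` before being handed to `fit`.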
We can add a new parameter
then update finally, the API usage would be
@scottjlee Would you be able to pick this item in this sprint?
yeah either me or @omatthew98 can pick it up for the upcoming sprint (assuming in the next 6 weeks)
Hey @scottjlee, this issue looks interesting to me and I'm more than happy to help resolve this problem in the next couple days if you don't have spare cycles to tackle it. I am also thinking about adding tests to
@jeffreyjeffreywang that sounds great, go for it! thanks!
…_tf (#45701) (#46784)

`tf.keras.Model.fit` lets callers pass sample weights along with the input dataset (https://www.tensorflow.org/versions/r2.8/api_docs/python/tf/keras/Model#fit). However, `ray.data.Dataset.to_tf()` does not support this feature: callers can only generate data instances consisting of features and labels.

This PR extends `ray.data.Dataset.to_tf()` so that callers can specify additional metadata associated with each data sample. Specifically, `additional_columns` is introduced as a keyword-only argument to `ray.data.Dataset.to_tf()` and `ray.DataIterator.to_tf()`. When `additional_columns` is not specified, the API behavior is unchanged. When it is provided, the APIs yield `additional_metadata` along with `features` and `labels`.

To pass sample weights along with sample features and labels in TensorFlow, create a `tf.data.Dataset` from `ray.data.Dataset.to_tf()` and specify `additional_columns="weight"`, where `weight` is the column that stores sample weights in the `ray.data.Dataset`.

We do not name the new argument `sample_weight` or `weight`, and we do not limit its type to `str`, because there may be other metadata associated with each sample that callers want to yield while iterating through the dataset. `additional_type_spec` is inferred with the same heuristic used for `feature_type_spec` and `label_type_spec`.

Existing tests are extended to cover calls both with and without the `additional_columns` argument, so that existing callers of `ray.data.Dataset.to_tf()` and `ray.DataIterator.to_tf()` keep working as this change rolls out.

## Related issue number

Resolves #45701

---------

Signed-off-by: jeffreyjeffreywang <[email protected]>
Co-authored-by: jeffreyjeffreywang <[email protected]>
Co-authored-by: Scott Lee <[email protected]>
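The behavior described in the commit message can be modelled in plain Python. This is an illustrative sketch only, not Ray's implementation: `iter_tuples` and `select` are hypothetical names, batches are assumed to be dicts of columns, and the rule that a single column name yields one value while a list of names yields a dict is an assumption made for the example. What it shows is the contract the PR describes: two-element tuples when `additional_columns` is omitted, three-element tuples when it is given.

```python
from typing import Dict, Iterable, Iterator, List, Optional, Union

Columns = Union[str, List[str]]
Batch = Dict[str, list]  # hypothetical layout: column name -> per-sample values


def select(batch: Batch, columns: Columns):
    """One column name gives its values; a list of names gives a dict (assumed)."""
    if isinstance(columns, str):
        return batch[columns]
    return {name: batch[name] for name in columns}


def iter_tuples(
    batches: Iterable[Batch],
    feature_columns: Columns,
    label_columns: Columns,
    *,  # additional_columns is keyword-only, as in the commit message
    additional_columns: Optional[Columns] = None,
) -> Iterator[tuple]:
    for batch in batches:
        features = select(batch, feature_columns)
        labels = select(batch, label_columns)
        if additional_columns is None:
            # No extra argument: same (features, labels) shape as before.
            yield features, labels
        else:
            yield features, labels, select(batch, additional_columns)


batches = [{"x": [1.0, 2.0], "y": [0, 1], "weight": [0.5, 2.0], "id": [7, 8]}]

pairs = list(iter_tuples(batches, "x", "y"))
weighted = list(iter_tuples(batches, "x", "y", additional_columns="weight"))
multi = list(iter_tuples(batches, "x", "y", additional_columns=["weight", "id"]))
```

Keeping the argument generic rather than calling it `sample_weight` is what allows the last call, where more than one metadata column travels with each batch.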
Description
Add support for ray.dataset.to_tf to pass sample weights
https://www.tensorflow.org/versions/r2.8/api_docs/python/tf/keras/Model#fit
API does not currently support the feature:
https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.to_tf.html#ray-data-dataset-to-tf
Use case
No response