-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Data] Extend API to enable passing sample weights via ray.dataset.to_tf (#45701) #46784
[Data] Extend API to enable passing sample weights via ray.dataset.to_tf (#45701) #46784
Conversation
Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
…luding additional column Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for updating all the tests!
can we also add an additional unit test, where we use Dataset.to_tf()
with sample weights passed, then we call keras.Model.fit()
with the sample weights? so that we can verify that the original intended behavior
Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
@scottjlee Would appreciate your quick review when you have time, no rush 😄 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the contribution!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lgtm, thanks for the contribution!
Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
@jeffreyjeffreywang let me know when the docs example is added, and i can take a final look before we get this merged |
Hey @scottjlee, definitely, I plan to get it done today. |
Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
Signed-off-by: jeffreyjeffreywang <[email protected]>
Hey @scottjlee, I've added some documentation to explain the spec arguments and the new argument I introduced. Doc examples are validated with ray.data.Dataset.to_tfray.data.DataIterator.to_tf |
Thanks, the docs changes look good to me. I am merging in latest master to resolve the build issue (unrelated to this PR) |
Thanks Scott! Looking forward to contributing more 😃 |
Why are these changes needed?
tensorflow.keras.model.fit
supports callers to pass in sample weights along with input dataset (https://www.tensorflow.org/versions/r2.8/api_docs/python/tf/keras/Model#fit). However,ray.data.Dataset.to_tf()
does not support this feature. Currently, callers can only generate data instances comprised of features and labels.This PR extends
ray.data.Dataset.to_tf()
to allow callers to specify additional metadata associated with data samples. Specifically,additional_columns
is introduced as a forced kwargs toray.data.Dataset.to_tf()
andray.DataIterator.to_tf()
. Whenadditional_columns
is not specified, there is no change to the API behavior. By contrast, whileadditional_columns
is provided, the APIs will yieldadditional_metadata
along withfeatures
andlabels
. To leverage the functionality for passing sample weights along with sample features and labels in tensorflow, one can createtf.data.Dataset
fromray.data.Dataset.to_tf()
and specifyingadditional_columns="weight"
asweight
is the column for storing sample weights inray.data.Dataset
.We do not explicitly name the new argument
sample_weight
orweight
and do not limit its type to bestr
as there may be other metadata associated with each sample that we want to yield while iterating through the dataset.We follow the heuristic of inferring
feature_type_spec
andlabel_type_spec
when determiningadditional_type_spec
. Existing tests are extended to validate when the APIs are invoked with and without theadditional_columns
argument to ensure existing callers won't fail to invokeray.data.Dataset.to_tf()
andray.DataIterator.to_tf()
as this new change rolls out.Related issue number
Resolves #45701
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.