Commit
[Data] Extend API to enable passing sample weights via ray.dataset.to_tf (#45701) (#46784)

`tensorflow.keras.Model.fit` lets callers pass sample weights along with the input dataset (https://www.tensorflow.org/versions/r2.8/api_docs/python/tf/keras/Model#fit). However, `ray.data.Dataset.to_tf()` does not support this feature. Currently, callers can only generate data instances composed of features and labels.

This PR extends `ray.data.Dataset.to_tf()` to allow callers to specify additional metadata associated with data samples. Specifically, `additional_columns` is introduced as a keyword-only argument to `ray.data.Dataset.to_tf()` and `ray.DataIterator.to_tf()`. When `additional_columns` is not specified, there is no change to the API behavior. When `additional_columns` is provided, the APIs yield `additional_metadata` along with `features` and `labels`.

To pass sample weights along with sample features and labels in TensorFlow, create a `tf.data.Dataset` from `ray.data.Dataset.to_tf()` and specify `additional_columns="weight"`, where `weight` is the column that stores sample weights in the `ray.data.Dataset`.

We do not explicitly name the new argument `sample_weight` or `weight`, and we do not limit its type to `str`, because there may be other metadata associated with each sample that we want to yield while iterating through the dataset. We follow the same heuristic used to infer `feature_type_spec` and `label_type_spec` when determining `additional_type_spec`.

Existing tests are extended to cover invoking the APIs both with and without the `additional_columns` argument, so that existing callers of `ray.data.Dataset.to_tf()` and `ray.DataIterator.to_tf()` continue to work as this change rolls out.

## Related issue number

Resolves #45701

---------

Signed-off-by: jeffreyjeffreywang <[email protected]>
Co-authored-by: jeffreyjeffreywang <[email protected]>
Co-authored-by: Scott Lee <[email protected]>
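
The tuple shapes described in this commit message can be illustrated without Ray or TensorFlow installed. The sketch below is not Ray's implementation: `iter_training_tuples` and `_select` are hypothetical helpers, batches are modeled as plain dicts of lists, and returning a dict when a list of column names is given is an assumption made for illustration. It only shows the intended contract: a 2-tuple when `additional_columns` is omitted, and a 3-tuple when it is provided.

```python
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union

# Hypothetical stand-ins: a batch maps column names to lists of values.
Batch = Dict[str, list]
Columns = Union[str, List[str]]


def _select(batch: Batch, columns: Columns):
    """Return one column as-is, or a dict of columns when a list is given."""
    if isinstance(columns, str):
        return batch[columns]
    return {name: batch[name] for name in columns}


def iter_training_tuples(
    batches: Iterable[Batch],
    feature_columns: Columns,
    label_columns: Columns,
    *,  # everything after this marker is keyword-only
    additional_columns: Optional[Columns] = None,
) -> Iterator[Tuple]:
    """Yield (features, labels) or (features, labels, additional) per batch."""
    for batch in batches:
        features = _select(batch, feature_columns)
        labels = _select(batch, label_columns)
        if additional_columns is None:
            # Unchanged behavior: existing callers still receive 2-tuples.
            yield features, labels
        else:
            # New behavior: extra per-sample metadata rides along.
            yield features, labels, _select(batch, additional_columns)


batches = [{"x": [1.0, 2.0], "y": [0, 1], "weight": [0.5, 2.0]}]
without_extra = list(iter_training_tuples(batches, "x", "y"))
with_weights = list(
    iter_training_tuples(batches, "x", "y", additional_columns="weight")
)
```

With the real API, the equivalent call would presumably pass `additional_columns="weight"` to `to_tf()` alongside the feature and label columns. Keras `Model.fit` accepts a dataset that yields `(inputs, targets, sample_weights)` tuples, which is why a 3-tuple is enough to carry the weights into training.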
1 parent 1a02e3c, commit 06fb5fc
Showing 4 changed files with 388 additions and 79 deletions.