[Data] Add API support for passing sample weights via ray.dataset.to_tf #45701

Closed
japneet-anyscale opened this issue Jun 3, 2024 · 7 comments · Fixed by #46784
Labels
data Ray Data-related issues enhancement Request for new feature and/or capability good first issue Great starter issue for someone just starting to contribute to Ray P1 Issue that should be fixed within a few weeks

Comments

@japneet-anyscale

Description

Add support for ray.dataset.to_tf to pass sample weights

https://www.tensorflow.org/versions/r2.8/api_docs/python/tf/keras/Model#fit

API does not currently support the feature:
https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.to_tf.html#ray-data-dataset-to-tf

Use case

No response

@japneet-anyscale japneet-anyscale added enhancement Request for new feature and/or capability triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jun 3, 2024
@anyscalesam anyscalesam added data Ray Data-related issues and removed enhancement Request for new feature and/or capability labels Jun 3, 2024
@scottjlee scottjlee added enhancement Request for new feature and/or capability P1 Issue that should be fixed within a few weeks and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jun 26, 2024
@scottjlee
Contributor

scottjlee commented Jun 26, 2024

@japneet-anyscale could you give more details / sample usage on how the new API would be used?

If I understand correctly, Model.fit(...) is a method on a Keras model, while Dataset.to_tf(...) produces a TF Dataset, so the two operate on different types (fitting a model vs. building a dataset).

@scottjlee
Contributor

After internal discussion, got additional context. The current workaround is to write a custom generator that iterates over the Dataset's batches and yields sample_weights alongside the features and labels:

def generator():
    # Runs inside a method, so self.dataset and the column/type-spec
    # variables come from the enclosing scope.
    for batch in self.dataset.iter_batches(
        batch_size=batch_size,
    ):
        assert isinstance(batch, dict)
        features = convert_batch_to_tensors(
            batch, columns=feature_columns, type_spec=feature_type_spec
        )
        labels = convert_batch_to_tensors(
            batch, columns=label_columns, type_spec=label_type_spec
        )
        sample_weights = convert_batch_to_tensors(
            batch, columns=sample_weight_column, type_spec=sample_weight_spec
        )
        # Yield a 3-tuple instead of the usual (features, labels) pair.
        yield features, labels, sample_weights
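For readers who want to try the workaround outside Ray's internals, here is a minimal sketch of the same idea. The names `weighted_batches` and `as_tf_dataset` are hypothetical, and the batches are assumed to be plain dicts of equal-length lists rather than whatever `iter_batches` actually returns. It relies on `tf.data.Dataset.from_generator` with an `output_signature`, and on Keras treating a dataset element of the form `(inputs, targets, sample_weights)` as carrying per-sample weights.

```python
from typing import Callable, Dict, Iterable, Iterator, List, Sequence, Tuple

Batch = Dict[str, List[float]]
WeightedBatch = Tuple[List[List[float]], List[float], List[float]]


def weighted_batches(
    batches: Iterable[Batch],
    feature_columns: Sequence[str],
    label_column: str,
    sample_weight_column: str,
) -> Iterator[WeightedBatch]:
    """Yield (features, labels, sample_weights) for each dict batch."""
    for batch in batches:
        # Stack the feature columns row-wise: one inner list per sample.
        features = [list(row) for row in zip(*(batch[c] for c in feature_columns))]
        yield features, list(batch[label_column]), list(batch[sample_weight_column])


def as_tf_dataset(make_generator: Callable[[], Iterator[WeightedBatch]], num_features: int):
    """Wrap a zero-argument generator factory in a tf.data.Dataset.

    TensorFlow is imported lazily so the pure-Python part above can be
    used and tested without it.
    """
    import tensorflow as tf

    return tf.data.Dataset.from_generator(
        make_generator,
        output_signature=(
            tf.TensorSpec(shape=(None, num_features), dtype=tf.float32),  # features
            tf.TensorSpec(shape=(None,), dtype=tf.float32),  # labels
            tf.TensorSpec(shape=(None,), dtype=tf.float32),  # sample weights
        ),
    )
```

With TensorFlow installed, something like `as_tf_dataset(lambda: weighted_batches(batches, ["x1", "x2"], "y", "w"), num_features=2)` should produce a dataset that can be passed straight to `model.fit(...)`, with the third tuple element used as the sample weights.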

@scottjlee
Contributor

We can add a new parameter additional_type_spec to to_tf():

def to_tf(
    ...,
    additional_type_spec: Optional[Dict[str, "tf.TypeSpec"]] = None,
)

then update generator() to output sample_weights, similar to the example in the previous message above: https://github.com/ray-project/ray/blob/releases/2.11.0/python/ray/data/iterator.py#L820-L835

finally, the API usage would be

to_tf(..., additional_type_spec={sample_weight_column: sample_weight_spec})

@scottjlee scottjlee added the good first issue Great starter issue for someone just starting to contribute to Ray label Jun 26, 2024
@galyna-anyscale

@scottjlee Would you be able to pick this item in this sprint?

@scottjlee
Contributor

yeah either me or @omatthew98 can pick it up for the upcoming sprint (assuming in the next 6 weeks)

@jeffreyjeffreywang
Contributor

Hey @scottjlee, this issue looks interesting to me and I'm more than happy to help resolve this problem in the next couple days if you don't have spare cycles to tackle it. I am also thinking about adding tests to ray/python/ray/data/tests/test_tf.py as we introduce a new parameter to to_tf. Let me know if you prefer taking it over, or otherwise I'll publish a PR in the next couple days! Thanks!

@scottjlee
Contributor

@jeffreyjeffreywang that sounds great, go for it! thanks!

bveeramani pushed a commit that referenced this issue Aug 5, 2024
…_tf (#45701) (#46784)

`tensorflow.keras.Model.fit` lets callers pass sample weights along with
the input dataset
(https://www.tensorflow.org/versions/r2.8/api_docs/python/tf/keras/Model#fit).
However, `ray.data.Dataset.to_tf()` does not support this feature:
callers can only generate data instances composed of features and
labels.

This PR extends `ray.data.Dataset.to_tf()` to allow callers to specify
additional metadata associated with data samples. Specifically,
`additional_columns` is introduced as a keyword-only argument to
`ray.data.Dataset.to_tf()` and `ray.data.DataIterator.to_tf()`. When
`additional_columns` is not specified, the API behavior is unchanged.
When `additional_columns` is provided, the APIs yield
`additional_metadata` along with `features` and `labels`. To pass sample
weights along with sample features and labels in TensorFlow, create a
`tf.data.Dataset` from `ray.data.Dataset.to_tf()` with
`additional_columns="weight"`, where `weight` is the column that stores
sample weights in the `ray.data.Dataset`.

We do not explicitly name the new argument `sample_weight` or `weight`
and do not limit its type to be `str` as there may be other metadata
associated with each sample that we want to yield while iterating
through the dataset.

We infer `additional_type_spec` with the same heuristic used for
`feature_type_spec` and `label_type_spec`. Existing tests are extended
to invoke the APIs both with and without the `additional_columns`
argument, so that existing callers of `ray.data.Dataset.to_tf()` and
`ray.data.DataIterator.to_tf()` keep working as this change rolls out.

## Related issue number
Resolves #45701 
---------

Signed-off-by: jeffreyjeffreywang <[email protected]>
Co-authored-by: jeffreyjeffreywang <[email protected]>
Co-authored-by: Scott Lee <[email protected]>
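
As a rough illustration of the merged behavior described in the commit message, the sketch below calls `to_tf()` with `additional_columns="weight"`. The helper names (`to_weighted_tf_dataset`, `demo`) and the column names `x`, `y`, and `weight` are assumptions made for this example, and `demo()` needs a Ray release that includes this change plus TensorFlow.

```python
def to_weighted_tf_dataset(ds, batch_size: int = 32):
    """Convert a Ray Dataset with columns x, y, weight into a tf.data.Dataset.

    `ds` is expected to be a ray.data.Dataset (or DataIterator); per the
    commit message, passing additional_columns makes each element a
    (features, labels, additional_metadata) tuple.
    """
    return ds.to_tf(
        feature_columns="x",
        label_columns="y",
        additional_columns="weight",  # column holding per-sample weights
        batch_size=batch_size,
    )


def demo():
    """Build a tiny dataset and iterate the resulting 3-tuples."""
    import ray  # imported lazily; requires Ray Data and TensorFlow

    ds = ray.data.from_items(
        [{"x": float(i), "y": float(i % 2), "weight": 1.0 + (i % 2)} for i in range(8)]
    )
    tf_ds = to_weighted_tf_dataset(ds, batch_size=4)
    for features, labels, weights in tf_ds:
        print(features.shape, labels.shape, weights.shape)
    # A compiled Keras model could then consume the weights via model.fit(tf_ds).
```

Calling `demo()` in an environment with both libraries installed should print one line of batch shapes per batch. Leaving out `additional_columns` should give back the original `(features, labels)` pairs.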
5 participants