Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for RaggedTensors for TensorFlow backend #385

Merged
merged 8 commits into from
Jun 26, 2023

Conversation

tirthasheshpatel
Copy link
Contributor

Keras Core didn't allow passing RaggedTensors to methods of keras_core.Model even when using the Tensorflow backend. Since many models in Keras CV and Keras NLP have ragged tensors ingrained in thier APIs, it would be good to offer support for them at least when using Tensorflow as the backend.

@ianstenbit

output = tf.random.stateless_categorical(
logits, num_samples, seed=seed
)
output = tf.random.stateless_categorical(logits, num_samples, seed=seed)
Copy link
Contributor Author

@tirthasheshpatel tirthasheshpatel Jun 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change made by the linter, please ignore.

Comment on lines 34 to 55
logits,
num_samples,
replacement=True,
generator=generator,
Copy link
Contributor Author

@tirthasheshpatel tirthasheshpatel Jun 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change made by the linter, please ignore.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Do you think we could support sparse tensors as well, without casting them to dense?

@@ -826,3 +826,50 @@ def is_tpu_strat(k):
if is_tpu_strat(clz):
return True
return any(map(_is_tpu_strategy_class, clz.__bases__))


# This function is taken from keras.engine.training.potentially_ragged_concat
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment will be soon outdated (Keras Core will become Keras) and general Keras Core used a lot of code from tf.keras, so you can just remove the comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

keras_core.backend.backend() != "tensorflow",
reason="Only tensorflow supports raggeds",
)
def test_trainer_with_raggeds(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also test __call__ing a layer with ragged tensors and __calling__ing a Functional model and Sequential model since these are all different code paths

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!


model = ExampleModel()
x = tf.ragged.constant([[1], [2, 3]])
model.compile(optimizer="adam")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would it behave with losses and metrics? Can we reproduce the tf.keras behavior?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would it behave with losses and metrics?

Losses and Metrics don't work with RaggedTensors because of the way they are implemented in keras core. They can be supported but I'd leave that to a follow-up PR.

@tirthasheshpatel
Copy link
Contributor Author

Do you think we could support sparse tensors as well, without casting them to dense?

Will look into this. If possible, will create a follow-up PR.


class ExampleModel(base_class):
def __init__(self, input_shape=(None,)):
if base_class is keras_core.Functional:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work for testing Functional because the call method is overridden. You'll need to use model = Functional(...) directly, without overriding call.

In this case you'll want to use a custom layer that supports Ragged (unless we already have layers that support Ragged...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Right now, the tfnp API doesn't support RaggedTensors, so, the tests just make sure that keras_core allows raggeds as inputs and the outputs remain ragged. Let me know if you had something else in mind.

"""Concats `Tensor`s along their first dimension.

Args:
tensors: List of `Tensor`s.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use 4 space indent

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

# model. We just test that passing RaggedTensors works for
# now.
outputs = inputs
super().__init__(inputs=inputs, outputs=outputs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you want to use at least 1 layer to check it actually works. Just make a custom layer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you

@fchollet fchollet merged commit 1bf82c3 into main Jun 26, 2023
5 checks passed
@tirthasheshpatel tirthasheshpatel deleted the add-ragged-support branch June 26, 2023 23:16
@tirthasheshpatel
Copy link
Contributor Author

Thanks for the reviews @fchollet!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants