Add support for RaggedTensors for TensorFlow backend #385
Conversation
output = tf.random.stateless_categorical(
    logits, num_samples, seed=seed
)
output = tf.random.stateless_categorical(logits, num_samples, seed=seed)
Unrelated change made by the linter, please ignore.
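For context, `tf.random.stateless_categorical` is a deterministic sampler: the same logits and seed always produce the same draws. A minimal sketch (the logits values here are illustrative, not from the PR):

```python
import tensorflow as tf

# Unnormalized log-probabilities for a batch of 2 distributions over 4 classes.
logits = tf.math.log([[0.1, 0.2, 0.3, 0.4],
                      [0.4, 0.3, 0.2, 0.1]])

# Stateless sampling: the same (logits, seed) pair always yields the same draws.
samples = tf.random.stateless_categorical(logits, num_samples=3, seed=[1, 2])
# samples has shape (2, 3), holding class indices in [0, 4).
```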
keras_core/backend/torch/random.py
Outdated
logits,
num_samples,
replacement=True,
generator=generator,
Unrelated change made by the linter, please ignore.
Thanks for the PR! Do you think we could support sparse tensors as well, without casting them to dense?
@@ -826,3 +826,50 @@ def is_tpu_strat(k):
    if is_tpu_strat(clz):
        return True
    return any(map(_is_tpu_strategy_class, clz.__bases__))

# This function is taken from keras.engine.training.potentially_ragged_concat
The comment will soon be outdated (Keras Core will become Keras), and in general Keras Core used a lot of code from tf.keras, so you can just remove the comment.
Done, thanks!
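For reference, the gist of that helper (a simplified paraphrase, not the exact Keras implementation, and the function name here is hypothetical) is to concatenate dense tensors normally when their non-batch shapes agree, and fall back to a ragged concat otherwise:

```python
import tensorflow as tf

def concat_possibly_ragged(tensors):
    """Concat along axis 0, going ragged when non-batch shapes differ.

    Simplified sketch; the real helper also handles sparse tensors
    and partially unknown shapes.
    """
    if isinstance(tensors[0], tf.RaggedTensor):
        # tf.concat dispatches on RaggedTensor.
        return tf.concat(tensors, axis=0)
    non_batch_shapes = {tuple(t.shape[1:]) for t in tensors}
    if len(non_batch_shapes) == 1:
        # All non-batch dimensions agree: a plain dense concat suffices.
        return tf.concat(tensors, axis=0)
    # Shapes disagree beyond the batch dimension: concatenate as ragged.
    return tf.concat(
        [tf.RaggedTensor.from_tensor(t) for t in tensors], axis=0
    )
```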
keras_core/trainers/trainer_test.py
Outdated
keras_core.backend.backend() != "tensorflow",
reason="Only tensorflow supports raggeds",
)
def test_trainer_with_raggeds(self):
Please also test `__call__`ing a layer with ragged tensors, and `__call__`ing a Functional model and a Sequential model, since these are all different code paths.
Done!
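A sketch of what such tests can look like, using `tf.keras` as a stand-in for `keras_core` (the APIs mirror each other; `Passthrough` and `Wrapper` are hypothetical names). The Functional and Sequential paths follow the same pattern but need a ragged-aware `Input`, which varies across Keras versions, so only the direct-call paths are shown:

```python
import tensorflow as tf

class Passthrough(tf.keras.layers.Layer):
    """A trivial layer: returns its input, so ragged stays ragged."""
    def call(self, inputs):
        return inputs

class Wrapper(tf.keras.Model):
    """A subclassed model wrapping the layer above."""
    def __init__(self):
        super().__init__()
        self.inner = Passthrough()

    def call(self, inputs):
        return self.inner(inputs)

x = tf.ragged.constant([[1.0], [2.0, 3.0]])

y_layer = Passthrough()(x)   # code path 1: calling a layer directly
y_model = Wrapper()(x)       # code path 2: calling a subclassed model
```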
keras_core/trainers/trainer_test.py
Outdated
model = ExampleModel()
x = tf.ragged.constant([[1], [2, 3]])
model.compile(optimizer="adam")
How would it behave with losses and metrics? Can we reproduce the tf.keras behavior?
> How would it behave with losses and metrics?

Losses and metrics don't work with `RaggedTensor`s because of the way they are implemented in Keras Core. They can be supported, but I'd leave that to a follow-up PR. Will look into this; if possible, will create a follow-up PR.
Force-pushed from 59b4417 to a951344
keras_core/trainers/trainer_test.py
Outdated
class ExampleModel(base_class):
    def __init__(self, input_shape=(None,)):
        if base_class is keras_core.Functional:
This won't work for testing Functional because the `call` method is overridden. You'll need to use `model = Functional(...)` directly, without overriding `call`.

In this case you'll want to use a custom layer that supports ragged tensors (unless we already have layers that support them...)
Fixed. Right now, the tfnp API doesn't support `RaggedTensor`s, so the tests just make sure that `keras_core` allows ragged tensors as inputs and that the outputs remain ragged. Let me know if you had something else in mind.
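The invariant those tests check (ragged in, ragged out) also holds for most plain TensorFlow ops, since they dispatch on `RaggedTensor`; a minimal illustration:

```python
import tensorflow as tf

x = tf.ragged.constant([[1, 2], [3]])

# Elementwise ops preserve raggedness: the result keeps the row structure.
doubled = x * 2
# doubled is still a RaggedTensor: [[2, 4], [6]]
```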
Force-pushed from a156f67 to a62bdc9
"""Concats `Tensor`s along their first dimension.

Args:
    tensors: List of `Tensor`s.
Use 4 space indent
Done.
# model. We just test that passing RaggedTensors works for
# now.
outputs = inputs
super().__init__(inputs=inputs, outputs=outputs)
I think you want to use at least 1 layer to check it actually works. Just make a custom layer
Done, thanks!
Force-pushed from 744f360 to f36c32b
LGTM, thank you
Thanks for the reviews @fchollet!
Keras Core didn't allow passing `RaggedTensor`s to methods of `keras_core.Model`, even when using the TensorFlow backend. Since many models in KerasCV and KerasNLP have ragged tensors ingrained in their APIs, it would be good to offer support for them, at least when using TensorFlow as the backend.

@ianstenbit
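For readers unfamiliar with them, ragged tensors represent batches whose rows have different lengths, which is common in KerasNLP (variable-length token sequences) and KerasCV (variable numbers of boxes per image). A minimal illustration:

```python
import tensorflow as tf

# Two rows of different lengths in a single batch.
x = tf.ragged.constant([[1, 2, 3], [4]])

# x.shape is (2, None): the second dimension is ragged.
lengths = x.row_lengths()             # per-row lengths: [3, 1]
dense = x.to_tensor(default_value=0)  # pad out to a dense (2, 3) tensor
```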