Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JAX RNN backend issue. #924

Merged
merged 1 commit into from
Sep 19, 2023
Merged

Fix JAX RNN backend issue. #924

merged 1 commit into from
Sep 19, 2023

Conversation

qlzh727
Copy link
Member

@qlzh727 qlzh727 commented Sep 19, 2023

This PR address several issues:

  1. The existing RNN layer is not training properly due the usage of a fresh StatelessScope in the jax.lax.scan loop. This is causing all the trainable variables to miss the mapping to the actual value in the training loop. Update them to use the parent Stateless scope if it is there. This will address the training issue Example nlp/lstm_seq2seq.py doesn't train with JAX backend #322

  2. The RNN layers with dropout will have a RNG seed update in the step function, which is not allowed by the jax.lax.scan. We noticed this issue since the updated seed is traced for non-trainable variable, and raise error when we try to put sharding constraint for distribution. Added a new method to pre-populate the dropout mask on the layer and make the inner_loop to be stateless.

  3. During the unit test, I noticed the stackRNNCell doesn't work with existing RNNCell, since it unwrap the list for the state, make the call function to keep the list if the input state is a list.

  4. Expose the SimpleRNN|GRU|LSTM cells in the init.py since they are public API.

@codecov
Copy link

codecov bot commented Sep 19, 2023

Codecov Report

Patch coverage: 80.95% and project coverage change: +0.03% 🎉

Comparison is base (c64de55) 79.73% compared to head (a340b01) 79.76%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #924      +/-   ##
==========================================
+ Coverage   79.73%   79.76%   +0.03%     
==========================================
  Files         318      318              
  Lines       28627    28645      +18     
  Branches     5447     5451       +4     
==========================================
+ Hits        22827    22850      +23     
+ Misses       4333     4332       -1     
+ Partials     1467     1463       -4     
Flag Coverage Δ
keras_core 79.69% <80.95%> (+0.03%) ⬆️
keras_core-numpy 60.40% <80.95%> (+0.01%) ⬆️
keras_core-tensorflow 66.84% <71.42%> (+0.02%) ⬆️
keras_core-torch 69.25% <71.42%> (+0.03%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Changed Coverage Δ
keras_core/backend/jax/rnn.py 9.09% <33.33%> (+0.54%) ⬆️
keras_core/layers/__init__.py 96.00% <100.00%> (+0.09%) ⬆️
keras_core/layers/rnn/rnn.py 85.98% <100.00%> (+0.95%) ⬆️
keras_core/layers/rnn/stacked_rnn_cells.py 87.01% <100.00%> (+5.43%) ⬆️

... and 1 file with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

@fchollet fchollet merged commit f2c3766 into keras-team:main Sep 19, 2023
7 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants