Example nlp/lstm_seq2seq.py doesn't train with JAX backend #322
Comments
I verified that the two backends give identical forward-pass numerics for the model. The saving issue was unrelated, and I've now fixed it. So the difference must lie in either the initializers or the optimization process.
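A forward-pass check like the one described above can be done by running the same input batch through the model under each backend, saving the outputs as NumPy arrays, and comparing them. The sketch below assumes the outputs were already saved; `compare_outputs` and the tolerances are hypothetical choices, not part of Keras.

```python
import numpy as np


def compare_outputs(out_a, out_b, rtol=1e-5, atol=1e-6):
    """Compare two forward-pass outputs (e.g. one per backend).

    Returns (max_abs_diff, matches), where `matches` uses np.allclose
    semantics: |a - b| <= atol + rtol * |b| elementwise.
    """
    a = np.asarray(out_a, dtype=np.float64)
    b = np.asarray(out_b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    max_abs_diff = float(np.max(np.abs(a - b)))
    return max_abs_diff, bool(np.allclose(a, b, rtol=rtol, atol=atol))


# Hypothetical outputs saved from each backend on the same input batch.
tf_out = np.array([[0.1, 0.9], [0.7, 0.3]], dtype=np.float32)
jax_out = tf_out + np.float32(1e-7)
diff, ok = compare_outputs(tf_out, jax_out)
print(diff < 1e-6, ok)
```

Float32 kernels on different backends rarely match bit for bit, so a small tolerance is the usual way to call two forward passes "identical".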
I've ruled out initialization, so it can only really be an optimizer or trainer issue.
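One way to narrow down an optimizer issue is to feed identical weights and gradients to each backend's optimizer for a single step and compare the results against a plain-NumPy reference. The sketch below assumes the example uses RMSprop (check the gist to confirm) and implements the textbook uncentered update; `rmsprop_step` is a hypothetical helper, and its epsilon placement may differ from the Keras implementation, so exact agreement may require matching that detail.

```python
import numpy as np


def rmsprop_step(w, g, v, lr=1e-3, rho=0.9, eps=1e-7):
    """One textbook (uncentered) RMSprop update; returns (new_w, new_v).

    v tracks a running average of squared gradients; the step divides
    the gradient by its root. Inputs are not modified in place.
    """
    v_new = rho * v + (1.0 - rho) * g * g
    w_new = w - lr * g / (np.sqrt(v_new) + eps)
    return w_new, v_new


# Identical weights, gradients and optimizer state for every backend.
w = np.array([1.0, -2.0])
g = np.array([0.5, 0.5])
v = np.zeros_like(w)
w1, v1 = rmsprop_step(w, g, v)
print(w1, v1)
```

If one backend's updated weights drift from this reference after a single step while the others match, the optimizer path on that backend is the likely culprit.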
It trains fine in torch as well.
Hi @fchollet, can I contribute here?
Absolutely! It's a pretty hard issue, I think. The problem is very non-obvious; everything looks good piecewise. TF and torch train well, but JAX trains very poorly (though it does still train). The code is here: https://gist.github.com/fchollet/f0c84ecbed8441e54820df8366a5a629
While working on the code, the "tensorflow" backend threw an error. Here is the Gist to reproduce it: https://gist.github.com/ariG23498/b8b4c0912a0a19dfe2ef8b29b3160943
Basically, the error message tells you that cuDNN can't be compiled with XLA. It's somewhat tricky to solve on our side. Should we disable jit_compile when there's a cuDNN-enabled layer, or avoid using cuDNN when we detect that we're tracing for XLA?
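The two options discussed above can be sketched as small decision functions. Everything here is hypothetical: `LayerInfo`, its `uses_cudnn` flag, `resolve_jit_compile` and `select_lstm_kernel` are illustrative names, not Keras APIs, and a real fix would have to detect cuDNN use and XLA tracing from inside the framework.

```python
from dataclasses import dataclass


@dataclass
class LayerInfo:
    # Hypothetical stand-in for a layer. `uses_cudnn` is an assumed flag,
    # not a real Keras attribute.
    name: str
    uses_cudnn: bool = False


def resolve_jit_compile(layers, requested=True):
    """Option 1: keep XLA compilation only if no layer would dispatch
    to a cuDNN kernel that XLA cannot compile."""
    if not requested:
        return False
    return not any(layer.uses_cudnn for layer in layers)


def select_lstm_kernel(tracing_for_xla, cudnn_available):
    """Option 2: keep jit_compile, but fall back to the generic
    (non-cuDNN) kernel whenever the layer is being traced for XLA."""
    if cudnn_available and not tracing_for_xla:
        return "cudnn"
    return "generic"


layers = [LayerInfo("encoder_lstm", uses_cudnn=True), LayerInfo("dense")]
print(resolve_jit_compile(layers))  # False
print(select_lstm_kernel(tracing_for_xla=True, cudnn_available=True))  # generic
```

Option 1 gives up XLA for the whole model to keep the fast cuDNN kernel; option 2 keeps XLA and pays for it with the slower generic recurrent kernel.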
To work around this, you can just pass
Right! Now it trains with the TF backend. It seems the same issue persists when loading the saved model, though. Do I also need to pass
You can just set
@ariG23498 I have fixed it in this commit: c9bce12. Please check that it works for you.
OK, I was able to verify that my fix in #888 resolves this issue. See https://colab.corp.google.com/drive/1z_QDD0uX9ApLJdFTxhHYxJNejUdgowkG#scrollTo=kxjMd749C1nA. Will send a PR very soon.
The model trains fine with TF. With JAX, the code runs but trains rather poorly. Need to find the root cause.
In addition, in this example, a model trained with TF cannot be reloaded with JAX.