Example nlp/lstm_seq2seq.py doesn't train with JAX backend #322

Closed
fchollet opened this issue Jun 11, 2023 · 12 comments
Labels: help wanted (Extra attention is needed)

@fchollet (Member) commented Jun 11, 2023

The model trains fine with TF. With JAX, the code runs but trains rather poorly. We need to find the root cause.

In addition, in this example, a model trained with TF cannot be reloaded with JAX.

@fchollet added the help wanted label on Jun 12, 2023
@fchollet (Member Author)

I verified that the two backends give identical forward pass numerics for the model.
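Roughly, a parity check like this can be scripted (a sketch, not the exact script used; the saved-model filename and input shapes below are placeholders):

```python
# Sketch: run once with KERAS_BACKEND=tensorflow and once with
# KERAS_BACKEND=jax, then compare the two dumps with np.allclose().
import numpy as np
import keras

backend = keras.backend.backend()
model = keras.models.load_model("s2s_model.keras", compile=False)

# Fixed, seeded inputs so both runs see identical data.
rng = np.random.default_rng(0)
encoder_in = rng.random((4, 20, 71)).astype("float32")  # placeholder shapes
decoder_in = rng.random((4, 30, 93)).astype("float32")

out = model.predict([encoder_in, decoder_in], verbose=0)
np.save(f"forward_{backend}.npy", out)

# After both runs:
#   a = np.load("forward_tensorflow.npy")
#   b = np.load("forward_jax.npy")
#   print(np.allclose(a, b, atol=1e-5))
```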

The saving issue was unrelated and I've now fixed it.

So the difference must lie in either the initializers or the optimization process.

@fchollet (Member Author)

I ruled out initialization, so it can only really be an optimizer or trainer issue.
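One way to perform that kind of elimination, as a sketch (the weights filename is a placeholder, and the fit arguments follow the example script's variable names):

```python
# Run 1 (KERAS_BACKEND=tensorflow): build the model and save its freshly
# initialized, untrained weights.
model.save_weights("init.weights.h5")

# Run 2 (KERAS_BACKEND=jax): build the same architecture, then load the
# TF-initialized weights so both backends start from identical parameters.
model.load_weights("init.weights.h5")
# If JAX still trains poorly from the exact same starting point,
# initialization is ruled out:
model.fit(
    [encoder_input_data, decoder_input_data],
    decoder_target_data,
    batch_size=64,
    epochs=10,
    validation_split=0.2,
)
```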

@fchollet (Member Author)

It trains fine in torch as well.

@shivance (Collaborator)

Hi @fchollet, can I contribute here?

@fchollet (Member Author)

Absolutely! It's a pretty hard issue, I think. The problem is very non-obvious: everything looks good piecewise. TF and torch train well, but JAX trains very poorly (while still training). The code is here:

https://gist.github.com/fchollet/f0c84ecbed8441e54820df8366a5a629

@ariG23498 (Collaborator)

While working on the code, the "tensorflow" backend threw an error.

Here is the Gist to reproduce the error: https://gist.github.com/ariG23498/b8b4c0912a0a19dfe2ef8b29b3160943

@fchollet (Member Author)

The error message basically tells you that cuDNN can't be compiled to XLA. It's somewhat tricky to solve on our side: either we disable jit_compile when there's a cuDNN-enabled layer, or we don't use cuDNN when we detect we're tracing for XLA?

@fchollet (Member Author) commented Jul 14, 2023

To work around this, you can just pass jit_compile=False to compile() when using TF.
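Concretely, something like this (a sketch; optimizer, loss, and metrics as in the example script):

```python
model.compile(
    optimizer="rmsprop",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
    jit_compile=False,  # skip XLA compilation of the cuDNN LSTM kernels
)
```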

@ariG23498 (Collaborator)

> To work around this, you can just pass jit_compile=False to compile() when using TF.

Right! Now it trains using the tf backend.

It seems the same issue persists when loading the saved model. Do I also need to pass compile=False when loading the saved model (when using TensorFlow as the backend)?

@fchollet (Member Author)

> Do I also need to pass compile=False when loading the saved model (when using TensorFlow as the backend)?

You can just set jit_compile = False on the model, I think.
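That is, something like this (a sketch; the filename is a placeholder):

```python
import keras

model = keras.models.load_model("s2s_model.keras")
# Disable XLA compilation on the already-compiled model:
model.jit_compile = False
```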

@fchollet (Member Author)

@ariG23498 I have fixed it in this commit; please check that it works for you: c9bce12

@qlzh727 (Member) commented Sep 19, 2023

OK, I was able to verify that my fix in #888 resolves this issue.

See https://colab.corp.google.com/drive/1z_QDD0uX9ApLJdFTxhHYxJNejUdgowkG#scrollTo=kxjMd749C1nA. I will send a PR very soon.
