jax[cuda] installation replaces current jax version with old jax-0.2.22 version #12307

Closed
gianlucadetommaso opened this issue Sep 9, 2022 · 15 comments
Label: bug (Something isn't working)

@gianlucadetommaso

Description

When I run
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
on my laptop, the latest version of jax, currently jax-0.3.17, gets replaced by jax-0.2.22. I don't think this was happening before. Here is a representative output of the command above:

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda]
  Using cached jax-0.3.17.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.16.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.15.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.14.tar.gz (990 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.13.tar.gz (951 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.12.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.11.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.10.tar.gz (939 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.9.tar.gz (937 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.8.tar.gz (935 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.7.tar.gz (944 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.6.tar.gz (936 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.5.tar.gz (946 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.4.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.3.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.2.tar.gz (926 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.1.tar.gz (912 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.0.tar.gz (896 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.28.tar.gz (887 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.27.tar.gz (873 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.26.tar.gz (850 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.25.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.24.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
WARNING: jax 0.2.22 does not provide the extra 'cuda'

What jax/jaxlib version are you using?

jax-0.3.17

Which accelerator(s) are you using?

GPU

Additional System Info

Mac

@gianlucadetommaso gianlucadetommaso added the bug Something isn't working label Sep 9, 2022
@gianlucadetommaso gianlucadetommaso changed the title jax[cuda] replaces current jax version with old jax-0.2.22 installation jax[cuda] installation replaces current jax version with old jax-0.2.22 installation Sep 9, 2022
@gianlucadetommaso gianlucadetommaso changed the title jax[cuda] installation replaces current jax version with old jax-0.2.22 installation jax[cuda] installation replaces current jax version with old jax-0.2.22 version Sep 9, 2022
@hawkinsp
Collaborator

hawkinsp commented Sep 9, 2022

The URL should be https://storage.googleapis.com/jax-releases/jax_releases.html

@yashk2810 Why is jax_cuda_releases.html even still present? Wouldn't it be better to delete it to prevent any problems like this?

@hawkinsp
Collaborator

hawkinsp commented Sep 9, 2022

Actually I'm wrong! You had the right URL the first time. I think the index was in a broken state. Try now?

@yashk2810 Is it possible the index file doesn't get updated atomically?

@gianlucadetommaso
Author

@hawkinsp still getting the same output!

@hawkinsp
Collaborator

hawkinsp commented Sep 9, 2022

My guess is that's related to some sort of caching of the index and it will fix itself soon. If you open it up in a web browser, do you see:

cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp310-none-manylinux2014_x86_64.whl
cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl
cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp38-none-manylinux2014_x86_64.whl
cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp39-none-manylinux2014_x86_64.whl
cuda11/jaxlib-0.3.15+cuda11.cudnn82-cp310-none-manylinux2014_x86_64.whl
cuda11/jaxlib-0.3.15+cuda11.cudnn82-cp37-none-manylinux2014_x86_64.whl
cuda11/jaxlib-0.3.15+cuda11.cudnn82-cp38-none-manylinux2014_x86_64.whl
cuda11/jaxlib-0.3.15+cuda11.cudnn82-cp39-none-manylinux2014_x86_64.whl

in the list?

You can always download the necessary wheel manually.
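If you do fetch a wheel by hand, the filename tells you which one fits your interpreter. A minimal sketch, assuming the standard wheel naming scheme `name-version-pythontag-abitag-platformtag.whl` and reusing the filenames listed above with the `cuda11/` directory prefix dropped; `parse_wheel` and `pick_wheel` are hypothetical helpers, not part of jax or pip.

```python
import sys

# Filenames copied from the index listing above (directory prefix dropped).
WHEELS = [
    "jaxlib-0.3.15+cuda11.cudnn805-cp310-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.15+cuda11.cudnn805-cp38-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.15+cuda11.cudnn805-cp39-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.15+cuda11.cudnn82-cp310-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.15+cuda11.cudnn82-cp37-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.15+cuda11.cudnn82-cp38-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.15+cuda11.cudnn82-cp39-none-manylinux2014_x86_64.whl",
]


def parse_wheel(filename):
    """Split a wheel filename (no build tag) into its five dash-separated fields."""
    stem = filename[: -len(".whl")]
    name, version, python_tag, abi_tag, platform_tag = stem.split("-")
    return {"name": name, "version": version, "python": python_tag,
            "abi": abi_tag, "platform": platform_tag}


def pick_wheel(wheels, python_tag, cudnn="cudnn82"):
    """Return the first wheel matching the Python tag and cuDNN local label, else None."""
    for filename in wheels:
        info = parse_wheel(filename)
        # Local version label after "+", e.g. "cuda11.cudnn82".
        local_parts = info["version"].partition("+")[2].split(".")
        if info["python"] == python_tag and cudnn in local_parts:
            return filename
    return None


# Python tag of the running interpreter, e.g. "cp38" for Python 3.8.
current_tag = "cp{}{}".format(*sys.version_info[:2])
print(pick_wheel(WHEELS, current_tag))
```

The chosen file can then be downloaded from the index and installed with `pip install <file>`; this sketch does not check the platform tag, so it only helps on a machine the wheel was built for.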

@gianlucadetommaso
Author

Yes, they are all there.

@hawkinsp
Collaborator

hawkinsp commented Sep 9, 2022

Oh! Wait! You are on a Mac. We don't support CUDA on Mac. So there's no matching wheel found.

You should install the CPU wheels on Mac (i.e., just pip install jaxlib).
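The rule of thumb here can be written down as a small decision function. A minimal sketch, assuming that the CUDA wheels on this index target Linux on x86_64 only (every filename listed earlier ends in `manylinux2014_x86_64`); `suggest_install` is a hypothetical helper, and the CPU command is an assumption that follows the "just pip install jaxlib" advice above.

```python
import platform

CUDA_INDEX = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def suggest_install(system=None, machine=None):
    """Suggest a pip command for this OS/CPU, or None if no prebuilt CUDA wheel fits."""
    system = system or platform.system()      # e.g. "Linux", "Darwin"
    machine = machine or platform.machine()   # e.g. "x86_64", "arm64", "aarch64"
    if system == "Linux" and machine == "x86_64":
        # The CUDA jaxlib wheels on the index are built for this combination.
        return f'pip install --upgrade "jax[cuda]" -f {CUDA_INDEX}'
    if system == "Darwin":
        # No CUDA support on Mac: use the CPU wheels instead.
        return "pip install --upgrade jax jaxlib"
    # Anything else (e.g. Linux on an ARM CPU): no prebuilt CUDA wheel on this index.
    return None


print(suggest_install())
```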

@hawkinsp hawkinsp closed this as completed Sep 9, 2022
@treyra
Contributor

treyra commented Oct 27, 2022

I'm seeing the same issue on Ubuntu as well. The most recent case is the behavior below, on a new Jetson Orin Dev Kit (so the OS is Linux for Tegra):

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Requirement already satisfied: jax[cuda] in ./.local/lib/python3.8/site-packages (0.3.23)
Requirement already satisfied: etils[epath] in ./.local/lib/python3.8/site-packages (from jax[cuda]) (0.8.0)
Requirement already satisfied: absl-py in ./.local/lib/python3.8/site-packages (from jax[cuda]) (1.3.0)
Requirement already satisfied: scipy>=1.5 in ./.local/lib/python3.8/site-packages (from jax[cuda]) (1.9.3)
Requirement already satisfied: numpy>=1.20 in ./.local/lib/python3.8/site-packages (from jax[cuda]) (1.23.4)
Requirement already satisfied: opt-einsum in ./.local/lib/python3.8/site-packages (from jax[cuda]) (3.3.0)
Requirement already satisfied: typing-extensions in ./.local/lib/python3.8/site-packages (from jax[cuda]) (4.4.0)
Collecting jax[cuda]
  Using cached jax-0.3.23-py3-none-any.whl
  Using cached jax-0.3.22.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.21.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.20.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.19.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.17.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.16.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.15.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.14.tar.gz (990 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.13.tar.gz (951 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.12.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.11.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.10.tar.gz (939 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.9.tar.gz (937 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.8.tar.gz (935 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.7.tar.gz (944 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.6.tar.gz (936 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.5.tar.gz (946 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.4.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.3.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.2.tar.gz (926 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.1.tar.gz (912 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.0.tar.gz (896 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.28.tar.gz (887 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.27.tar.gz (873 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.26.tar.gz (850 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.25.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.24.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.22-py3-none-any.whl
WARNING: jax 0.2.22 does not provide the extra 'cuda'
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.3.23
    Uninstalling jax-0.3.23:
      Successfully uninstalled jax-0.3.23
Successfully installed jax-0.2.22

But if I instead specify my CUDA (11.4) and cuDNN (8.3.2) versions, I get the expected behavior:

pip install --upgrade "jax[cuda114_cudnn832]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Requirement already satisfied: jax[cuda114_cudnn832] in ./.local/lib/python3.8/site-packages (0.2.22)
Collecting jax[cuda114_cudnn832]
  Using cached jax-0.3.23-py3-none-any.whl
WARNING: jax 0.3.23 does not provide the extra 'cuda114_cudnn832'
Requirement already satisfied: scipy>=1.5 in ./.local/lib/python3.8/site-packages (from jax[cuda114_cudnn832]) (1.9.3)
Requirement already satisfied: typing-extensions in ./.local/lib/python3.8/site-packages (from jax[cuda114_cudnn832]) (4.4.0)
Requirement already satisfied: absl-py in ./.local/lib/python3.8/site-packages (from jax[cuda114_cudnn832]) (1.3.0)
Requirement already satisfied: opt-einsum in ./.local/lib/python3.8/site-packages (from jax[cuda114_cudnn832]) (3.3.0)
Requirement already satisfied: numpy>=1.20 in ./.local/lib/python3.8/site-packages (from jax[cuda114_cudnn832]) (1.23.4)
Requirement already satisfied: etils[epath] in ./.local/lib/python3.8/site-packages (from jax[cuda114_cudnn832]) (0.8.0)
Requirement already satisfied: importlib_resources in ./.local/lib/python3.8/site-packages (from etils[epath]->jax[cuda114_cudnn832]) (5.10.0)
Requirement already satisfied: zipp in ./.local/lib/python3.8/site-packages (from etils[epath]->jax[cuda114_cudnn832]) (3.10.0)
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.2.22
    Uninstalling jax-0.2.22:
      Successfully uninstalled jax-0.2.22
Successfully installed jax-0.3.23

If I just run pip install --upgrade jax, I get the same result:

pip install --upgrade jax
Defaulting to user installation because normal site-packages is not writeable
Requirement already satisfied: jax in ./.local/lib/python3.8/site-packages (0.2.22)
Collecting jax
  Using cached jax-0.3.23-py3-none-any.whl
Requirement already satisfied: absl-py in ./.local/lib/python3.8/site-packages (from jax) (1.3.0)
Requirement already satisfied: opt-einsum in ./.local/lib/python3.8/site-packages (from jax) (3.3.0)
Requirement already satisfied: etils[epath] in ./.local/lib/python3.8/site-packages (from jax) (0.8.0)
Requirement already satisfied: typing-extensions in ./.local/lib/python3.8/site-packages (from jax) (4.4.0)
Requirement already satisfied: scipy>=1.5 in ./.local/lib/python3.8/site-packages (from jax) (1.9.3)
Requirement already satisfied: numpy>=1.20 in ./.local/lib/python3.8/site-packages (from jax) (1.23.4)
Requirement already satisfied: zipp in ./.local/lib/python3.8/site-packages (from etils[epath]->jax) (3.10.0)
Requirement already satisfied: importlib_resources in ./.local/lib/python3.8/site-packages (from etils[epath]->jax) (5.10.0)
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.2.22
    Uninstalling jax-0.2.22:
      Successfully uninstalled jax-0.2.22
Successfully installed jax-0.3.23

Is this just because pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html grabs the version of jax with the least restrictive requirements? In both cases the extra I provide isn't recognized, but with "[cuda]" pip falls back to an old version, while with "[cuda114_cudnn832]" or no extra at all it installs the newest release.
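A simplified model reproduces the pattern in these logs. It is a toy sketch, not pip's actual resolver, and it rests on two assumptions: recent jax releases declare a `cuda` extra that pins a CUDA jaxlib with no wheel for this platform, while jax 0.2.22 declares no `cuda` extra at all (the `WARNING` line in the first log says as much). A backtracking resolver then walks back from the newest release and settles on the first one whose requested extra is either satisfiable or undeclared, because an undeclared extra is ignored with a warning. The `RELEASES` table is hypothetical and trimmed.

```python
# Hypothetical, trimmed metadata: which extras each jax release declares (newest first).
RELEASES = [
    ("0.3.23", {"cuda"}),
    ("0.3.22", {"cuda"}),
    ("0.2.24", {"cuda"}),
    ("0.2.22", set()),  # per the log: "does not provide the extra 'cuda'"
]


def resolve(extra, cuda_jaxlib_available):
    """Pick a jax version the way a backtracking resolver might (toy model)."""
    for version, extras in RELEASES:
        if extra not in extras:
            # An undeclared extra is ignored with a warning, so this version is accepted.
            print(f"WARNING: jax {version} does not provide the extra '{extra}'")
            return version
        if cuda_jaxlib_available:
            return version
        # The declared extra needs a jaxlib wheel that doesn't exist here: backtrack.
    return None


print(resolve("cuda", cuda_jaxlib_available=False))
print(resolve("cuda114_cudnn832", cuda_jaxlib_available=False))
```

Under these assumptions, `"cuda"` backtracks all the way to 0.2.22, while the unknown `"cuda114_cudnn832"` extra is dropped immediately and the newest release wins, matching both logs above.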

@jakevdp
Collaborator

jakevdp commented Oct 27, 2022

@treyra What version of Python are you using?

@jakevdp
Collaborator

jakevdp commented Oct 27, 2022

Also, what happens if you run this?

$ pip install jaxlib==0.3.22+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

@treyra
Contributor

treyra commented Oct 27, 2022

Thanks for replying so fast! I'm using Python 3.8.10 (and pip 22.3).

Running that command gives me:

pip install jaxlib==0.3.22+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
ERROR: Could not find a version that satisfies the requirement jaxlib==0.3.22+cuda11.cudnn82 (from versions: none)
ERROR: No matching distribution found for jaxlib==0.3.22+cuda11.cudnn82

@jakevdp
Collaborator

jakevdp commented Oct 27, 2022

That wheel is in the index and the command works on Colab, so there must be some quirk in your own system. A couple of guesses:

  • perhaps you have a firewall that's blocking the jax releases URL?
  • perhaps your flavor of Linux is not compatible with manylinux2010 wheels?
  • perhaps your system is not compatible with x86_64 wheels?
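The last two guesses come down to the platform tag at the end of a wheel filename. A rough sketch, assuming only the legacy `manylinux2014_<arch>` form seen in the index listing; it ignores the minimum glibc version that a manylinux tag also implies, so treat it as a first filter rather than pip's real compatibility check. `wheel_matches_machine` is a hypothetical helper.

```python
import platform


def wheel_matches_machine(filename, system=None, machine=None):
    """Rough check: could this wheel's platform tag install on this OS/CPU?"""
    system = system or platform.system()      # e.g. "Linux", "Darwin"
    machine = machine or platform.machine()   # e.g. "x86_64", "aarch64"
    platform_tag = filename[: -len(".whl")].split("-")[-1]
    # Assumes the "manylinux2014_<arch>" form: architecture follows the first "_".
    family, _, arch = platform_tag.partition("_")
    if family.startswith("manylinux") and system != "Linux":
        return False  # manylinux wheels target Linux only
    return arch == machine


WHEEL = "jaxlib-0.3.15+cuda11.cudnn82-cp38-none-manylinux2014_x86_64.whl"
print(wheel_matches_machine(WHEEL))
```

On an ARM board that reports `aarch64`, the check fails for every `x86_64` wheel, which would surface in pip as "from versions: none".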

@treyra
Contributor

treyra commented Oct 27, 2022

Thank you, that might be the issue. When I just pip install jax, it installs jax 0.3.23 but not jaxlib alongside it. From #7097, it looks like jax isn't compatible with ARM CPUs? The Nvidia Jetson Orin uses an Arm Cortex CPU. I'm guessing that is also why it says above that you don't support Mac?

Maybe this should be a separate issue or be mentioned on #7097, but I didn't realize that jax wheels are only available for x86_64 architectures until I dug into the issues. Is that something that makes sense to add to the installation guide?
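The symptom described here (jax installs but jaxlib does not) can be confirmed from the installed package metadata. A minimal sketch using the standard-library `importlib.metadata` module (Python 3.8+, so it works on the 3.8.10 mentioned above); `installed_versions` is a hypothetical helper, and it only reports what is installed, not whether jaxlib can see a GPU.

```python
import platform
from importlib import metadata


def installed_versions(*names):
    """Map each distribution name to its installed version, or None if it is absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# On an affected machine one would expect a jax version but jaxlib mapped to None.
print(platform.machine(), installed_versions("jax", "jaxlib"))
```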

@jakevdp
Collaborator

jakevdp commented Oct 27, 2022

Yes, if the steps here didn't help with your problem, we should add a warning there about the issue you ran into. I'm not sure about the status of jaxlib wheel releases for ARM CPUs, but it sounds like that may be the culprit.

@treyra
Contributor

treyra commented Oct 27, 2022

Thank you for your help with this, I think that is the main issue. I found #7186 about compiling jax for an older version of the Jetson series, so I'll try following that.

The steps you linked were what I was following when I ran into this; I think adding a warning about ARM CPUs would clear up the confusion. I can try submitting a pull request for that if you would like.

@jakevdp
Collaborator

jakevdp commented Oct 27, 2022

Yes, a PR would be much appreciated. Thanks!

treyra added a commit to treyra/jax that referenced this issue Oct 27, 2022
Currently, non-x86_64 Linux architectures are not supported; see jax-ml#7097 for a request to change this. This can lead to installation confusion, as jax will install but jaxlib will not. For an example, see jax-ml#12307.

If this can be more clearly phrased or explained, let me know.
treyra added a commit to treyra/jax that referenced this issue Nov 4, 2022
Currently, non-x86_64 Linux architectures are not supported; see jax-ml#7097 for a request to change this. This can lead to installation confusion, as jax will install but jaxlib will not. For an example, see jax-ml#12307. This adds a note to the install sections for the relevant pip wheels.

5 participants