
[BUG] Jax - Flax compatibility error #98

Open
thomashirtz opened this issue Jul 7, 2024 · 3 comments
Labels: bug (Something isn't working), good first issue (Good for newcomers)


thomashirtz commented Jul 7, 2024

Describe the bug

Hello!

When using the Docker setup, I get the error cannot import name 'linear_util' from 'jax.extend' when running the examples. This seems to be caused by an incompatibility between the installed flax and jax versions: https://stackoverflow.com/questions/78210393/cannot-import-name-linear-util-from-jax (With these settings, my GPU, a 2070 Max-Q, is detected.)
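To confirm the mismatch from inside the container, a small diagnostic can print the installed jax, jaxlib and flax versions. It can also attempt the exact import that fails in the traceback below (from jax.extend import linear_util). This is a sketch. The helper names are hypothetical, and it relies only on the standard library's importlib.metadata plus that one import.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def jax_extend_has_linear_util() -> bool:
    """Try the same import that flax/core/axes_scan.py performs in the traceback."""
    try:
        from jax.extend import linear_util  # noqa: F401
    except ImportError:  # also covers jax not being installed at all
        return False
    return True


if __name__ == "__main__":
    for dist in ("jax", "jaxlib", "flax"):
        print(f"{dist}: {installed_version(dist)}")
    print("jax.extend.linear_util importable:", jax_extend_has_linear_util())
```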

I therefore tried to install version 0.4.24 by changing requirements.txt from jax>=0.4.10 to jax>=0.4.24 and changing line 36 of the Dockerfile to:

RUN if [ "$USE_CUDA" = true ] ; \
    then pip install "jax[cuda11]>=0.4.24" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \
    fi
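After rebuilding, it helps to check that the installed jax actually satisfies the new jax>=0.4.24 pin, rather than an older wheel being kept. The following is a minimal sketch. It assumes plain numeric release versions, drops a local label such as +cuda11.cudnn86, and does not handle pre-release suffixes. For full PEP 440 comparisons, packaging.version is the usual choice.

```python
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """'0.4.13+cuda11.cudnn86' -> (0, 4, 13); ignores the local label after '+'."""
    public = version.split("+", 1)[0]
    return tuple(int(part) for part in public.split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if the installed release is >= the minimum release."""
    return release_tuple(installed) >= release_tuple(minimum)


print(meets_minimum("0.4.13", "0.4.24"))  # False (jax 0.4.13, as in the pip freeze below)
print(meets_minimum("0.4.24", "0.4.24"))  # True
```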

However, JAX can then no longer use my GPU, and I get this warning:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Do you have any idea how to solve this?
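To see which backend JAX actually selected, rather than relying on the warning alone, you can query jax.default_backend() and jax.devices(). The following is a short sketch. The wrapper name is hypothetical, and it returns None when jax cannot be imported, so it runs in any environment.

```python
from typing import List, Optional, Tuple


def jax_backend_report() -> Optional[Tuple[str, List[str]]]:
    """Return (default backend, device descriptions), or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None
    backend = jax.default_backend()  # "cpu" after the fallback above; "gpu" when CUDA is usable
    devices = [str(device) for device in jax.devices()]
    return backend, devices


if __name__ == "__main__":
    report = jax_backend_report()
    if report is None:
        print("jax is not installed")
    else:
        backend, devices = report
        print("default backend:", backend)
        print("devices:", devices)
```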

Full traceback:

Traceback (most recent call last):
  File "/opt/project/xposure/stoix/systems/q_learning/ff_ddqn.py", line 6, in <module>
    import flashbax as fbx
  File "/xposure/lib/python3.10/site-packages/flashbax/__init__.py", line 16, in <module>
    from flashbax.buffers import (
  File "/xposure/lib/python3.10/site-packages/flashbax/buffers/__init__.py", line 16, in <module>
    from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer
  File "/xposure/lib/python3.10/site-packages/flashbax/buffers/prioritised_flat_buffer.py", line 25, in <module>
    from flashbax.buffers.prioritised_trajectory_buffer import (
  File "/xposure/lib/python3.10/site-packages/flashbax/buffers/prioritised_trajectory_buffer.py", line 39, in <module>
    from flashbax.buffers import sum_tree, trajectory_buffer
  File "/xposure/lib/python3.10/site-packages/flashbax/buffers/sum_tree.py", line 33, in <module>
    from flax.struct import dataclass
  File "/xposure/lib/python3.10/site-packages/flax/__init__.py", line 24, in <module>
    from flax import core
  File "/xposure/lib/python3.10/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast as broadcast
  File "/xposure/lib/python3.10/site-packages/flax/core/axes_scan.py", line 23, in <module>
    from jax.extend import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax.extend' (/xposure/lib/python3.10/site-packages/jax/extend/__init__.py)

To Reproduce

Steps to reproduce the behavior:

  1. Run make build
  2. Run ff_ddqn.py inside the Docker container

Possible Solution

Change the versions of flax and jax/jaxlib in requirements.txt and the Dockerfile.

Context (Environment)

Linux 24.04 with Docker.
This is the pip freeze output when running the Docker image with the repo's current settings:

absl-py==2.1.0
antlr4-python3-runtime==4.9.3
arch==7.0.0
arrow==1.3.0
attrs==23.2.0
black==24.4.2
blinker==1.8.2
boto3==1.34.140
botocore==1.34.140
bravado==11.0.3
bravado-core==6.1.1
brax==0.10.5
certifi==2024.7.4
cfgv==3.4.0
charset-normalizer==3.3.2
chex==0.1.86
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
colorcet==3.0.0
contextlib2==21.6.0
contourpy==1.2.1
craftax==1.4.3
cycler==0.12.1
decorator==5.1.1
distlib==0.3.8
distrax @ git+https://github.com/google-deepmind/distrax@0e449826b6be7603a56b98dbf64873cae3aa523e
dm-env==1.6
dm-tree==0.1.8
docker-pycreds==0.4.0
dotmap==1.3.30
etils==1.7.0
evosax==0.1.6
Farama-Notifications==0.0.4
filelock==3.15.4
flashbax @ git+https://github.com/instadeepai/flashbax@1c31b526e6374620395633d1699494f104543177
Flask==3.0.3
Flask-Cors==4.0.1
flax==0.8.5
fonttools==4.53.1
fqdn==1.5.1
fsspec==2024.6.1
future==1.0.0
gast==0.6.0
gitdb==4.0.11
GitPython==3.1.43
glfw==2.7.0
grpcio==1.64.1
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
gymnax==0.0.8
huggingface-hub==0.23.4
hydra-core==1.3.2
id-marl-eval @ git+https://github.com/instadeepai/marl-eval@f97a72350d954e31a70531f55d8fca50db0d25f0
identify==2.5.36
idna==3.7
imageio==2.34.2
imageio-ffmpeg==0.5.1
importlib-metadata==4.13.0
importlib_resources==6.4.0
isoduration==20.11.0
itsdangerous==2.2.0
jax==0.4.13
jaxlib==0.4.13+cuda11.cudnn86
jaxmarl==0.0.2
jaxopt==0.8.3
Jinja2==3.1.4
jmespath==1.0.1
jsonpointer==3.0.0
jsonref==1.1.0
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jumanji==1.0.0
kiwisolver==1.4.5
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.5
mctx==0.0.5
mdurl==0.1.2
ml-dtypes==0.4.0
ml_collections==0.1.1
monotonic==1.6
msgpack==1.0.8
mujoco==3.1.6
mujoco-mjx==3.1.6
mypy-extensions==1.0.0
neptune==1.10.4
nest-asyncio==1.6.0
nodeenv==1.9.1
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvcc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==9.2.0.82
nvidia-cufft-cu11==10.9.0.58
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
oauthlib==3.2.2
omegaconf==2.3.0
opt-einsum==3.3.0
optax @ git+https://github.com/google-deepmind/optax.git@10cf508f505acd99feac5c231c0f521895bb3a37
orbax-checkpoint==0.5.20
packaging==24.1
pandas==1.4.4
param==2.1.1
pathspec==0.12.1
patsy==0.5.6
pgx==2.0.1
pgx-minatar==0.5.1
pillow==10.4.0
platformdirs==4.2.2
pre-commit==3.7.1
protobuf==3.20.2
psutil==6.0.0
pyct==0.5.0
pygame==2.6.0
Pygments==2.18.0
PyJWT==2.8.0
PyOpenGL==3.1.7
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytinyrenderer==0.0.14
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
requests==2.32.3
requests-oauthlib==2.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rlax==0.1.6
rliable @ git+https://github.com/google-research/rliable@1171833f6706b6c25bbf042e2cb185a96fcf2ce6
rpds-py==0.18.1
s3transfer==0.10.2
safetensors==0.4.3
scipy==1.14.0
seaborn==0.13.2
sentry-sdk==2.7.1
setproctitle==1.3.3
simplejson==3.19.2
six==1.16.0
smmap==5.0.1
statsmodels==0.14.2
svgwrite==1.4.3
swagger-spec-validator==3.0.4
tdqm==0.0.1
tensorboard-logger==0.1.0
tensorboardX==2.6.2.2
tensorflow-probability==0.24.0
tensorstore==0.1.63
tomli==2.0.1
toolz==0.12.1
tqdm==4.66.4
trimesh==4.4.1
types-python-dateutil==2.9.0.20240316
typing_extensions==4.12.2
uri-template==1.3.0
urllib3==2.2.2
virtualenv==20.26.3
wandb==0.17.4
webcolors==24.6.0
websocket-client==1.8.0
Werkzeug==3.0.3
xminigrid @ git+https://github.com/corl-team/xland-minigrid.git@991f13c7885c24c82302a1ee3a68a24a29801a94
-e git+https://github.com/thomashirtz/xposure@af41ff3c6dc2262f7592831efa95a7da505e3b21#egg=xposure
zipp==3.19.2

thomashirtz added the bug (Something isn't working) label Jul 7, 2024
thomashirtz changed the title from [BUG] to [BUG] Jax - Flax compatibility error Jul 7, 2024
EdanToledo (Owner):

Hmm, let me look into this. Unfortunately, I don't currently have access to a GPU machine, so it'll be hard for me to test this. Regardless, this reminds me to raise the jax version in the requirements file. Just make sure that the base image you are pulling and the jax version you install target the same CUDA and cuDNN versions.
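The jaxlib line in the pip freeze above (jaxlib==0.4.13+cuda11.cudnn86) shows that, for these older wheels, the CUDA and cuDNN versions a jaxlib was built against are encoded in its local version label. The following is a minimal sketch for extracting them so they can be compared with the base image's CUDA/cuDNN tag. It assumes the +cudaXX.cudnnYY format seen there. Newer JAX releases that ship CUDA support as separate plugin packages may not use this label at all. The function name is hypothetical.

```python
import re
from typing import Optional, Tuple

# Matches a local version label such as "+cuda11.cudnn86" at the end of the string.
_CUDA_LABEL = re.compile(r"\+cuda(?P<cuda>\d+)\.cudnn(?P<cudnn>\d+)$")


def parse_jaxlib_build(version: str) -> Optional[Tuple[str, str]]:
    """Return (cuda, cudnn) tags from e.g. '0.4.13+cuda11.cudnn86', else None."""
    match = _CUDA_LABEL.search(version)
    if match is None:
        return None  # CPU-only build, or a release that does not use this label
    return match.group("cuda"), match.group("cudnn")


print(parse_jaxlib_build("0.4.13+cuda11.cudnn86"))  # ('11', '86')
print(parse_jaxlib_build("0.4.13"))  # None
```

Here "11" refers to CUDA 11 and "86" to cuDNN 8.6, which is what the base image should provide.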

EdanToledo (Owner):

@thomashirtz Did you ever figure out the issue?

thomashirtz (Author):

No, unfortunately I didn't. I don't have much time to debug this, so I stopped using Docker and switched to a venv.

EdanToledo added the good first issue (Good for newcomers) label Aug 27, 2024