From 95db9805a3eedc7925f31da6f32b7d9836ec63bf Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 4 Aug 2024 16:38:43 +0200 Subject: [PATCH] Avoid buggy version of JAX in CI Related to https://github.com/google/jax/issues/22751 --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 674bc52c7b..e4ff6724d6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -156,7 +156,7 @@ jobs: micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi - if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi + if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "jax!=0.4.31" "jaxlib!=0.4.31" numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi pip install pytest-sphinx