If you use JAXRL2 in your work, please cite this repository in publications:
@misc{jaxrl,
author = {Kostrikov, Ilya},
doi = {10.5281/zenodo.5535154},
month = {10},
title = {{JAXRL: Implementations of Reinforcement Learning algorithms in JAX}},
url = {https://github.com/ikostrikov/jaxrl2},
year = {2022},
note = {v2}
}
To install, run:
pip install --upgrade pip
pip install -e .
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: these wheels are only available on Linux.
For other CUDA versions, see the JAX installation instructions.
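To check that the CUDA-enabled JAX build is actually being used, a quick check like the following can help. This is a small sketch and not part of the repository: it only prints JAX's default backend and the devices it can see. On a working GPU install the backend should be `gpu`; if no GPU is visible, JAX falls back to `cpu`.

```python
# Sanity check after installation (a sketch, not part of the repository).
import jax

# Backend JAX will use by default, e.g. "gpu" on a working CUDA install.
backend = jax.default_backend()

# All devices visible to JAX on that backend.
devices = jax.devices()

print(backend, devices)
```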
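To run the tests (MUJOCO_GL=egl selects headless EGL rendering for MuJoCo, and the empty CUDA_VISIBLE_DEVICES hides all GPUs so the tests run on the CPU):
MUJOCO_GL=egl CUDA_VISIBLE_DEVICES= pytest tests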
Thanks to @evgenii-nikishin for helping with JAX and to @dibyaghosh for helping with vmapped ensembles.