NAVIX is a JAX-powered reimplementation of MiniGrid. Experiments that took 1 week now take 15 minutes.
Key features:
- Performance Boost: NAVIX is over 1000x faster than the original MiniGrid implementation, enabling faster experimentation and scaling. See a preliminary performance comparison here, and the full benchmarks here.
- XLA Compilation: Leverage the power of XLA to optimize NAVIX computations for many accelerators. NAVIX can run on CPU, GPU, and TPU.
- Autograd Support: Differentiate through environment transitions, opening up new possibilities such as learned world models.
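To give a feel for what differentiating through environment transitions means, here is a minimal pure-JAX sketch. It does not use NAVIX: `toy_step` is a hypothetical, differentiable point-mass transition standing in for an environment step. The rollout is unrolled with `jax.lax.scan`, so `jax.jit` and `jax.grad` can treat the whole trajectory as a single function of the action sequence.

```python
import jax
import jax.numpy as jnp

GOAL = jnp.array([1.0, 1.0])  # hypothetical target position


def toy_step(position, action):
    """Hypothetical differentiable transition (not a NAVIX API): move by 0.1 * action."""
    next_position = position + 0.1 * action
    return next_position, next_position


@jax.jit
def distance_to_goal(actions):
    """Rolls out an action sequence and returns the squared distance to GOAL."""
    final_position, _ = jax.lax.scan(toy_step, jnp.zeros(2), actions)
    return jnp.sum((final_position - GOAL) ** 2)


actions = jnp.zeros((10, 2))                 # 10 steps of 2-D actions
grads = jax.grad(distance_to_goal)(actions)  # one gradient row per time step
```

With all-zero actions the point stays at the origin, so every entry of `grads` is -0.2: increasing any action component moves the final position towards the goal.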
The library is in active development, and we are working on adding more environments and features. If you want to join the development and contribute, please open a discussion and let's have a chat!
First, install JAX by following the official installation guide for your OS and preferred accelerator: https://github.com/google/jax#installation. Then install NAVIX:
```bash
pip install navix
```
Or, for the latest version from source:
```bash
pip install git+https://github.com/epignatelli/navix
```
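Since NAVIX runs wherever JAX does, a quick sanity check after installing is to ask JAX which backend and devices it found. This snippet only uses standard JAX calls and does not import NAVIX.

```python
import jax

# Report the default backend (e.g. "cpu", "gpu" or "tpu") and the devices JAX can see
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX backend: {backend}, devices: {devices}")
```

If you expected a GPU or TPU but the backend reports `cpu`, revisit the accelerator-specific step of the JAX installation guide.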
You can view a full set of examples here (more coming), but here are the most common use cases.
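Running 1000 environments in parallel, each for a fixed number of random steps:
```python
import jax
import jax.numpy as jnp
import navix as nx

N_TIMESTEPS = 100  # number of random steps per environment


def run(seed):
    env = nx.make('MiniGrid-Empty-8x8-v0')  # Create the environment
    key = jax.random.PRNGKey(seed)
    reset_key, action_key = jax.random.split(key)  # separate keys for reset and actions
    timestep = env.reset(reset_key)
    actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, env.action_space.n)

    def body_fun(timestep, action):
        timestep = env.step(timestep, action)  # Update the environment state
        return timestep, ()

    return jax.lax.scan(body_fun, timestep, actions)[0]


# Vectorise over 1000 seeds and compile the whole batch of rollouts with XLA
final_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000))
```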
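`jax.vmap(run)` maps `run` over the leading axis of `jnp.arange(1000)`, so every array in `final_timestep` gains a leading axis of size 1000, one entry per seed. The pattern is easier to see on a toy problem: below, a hypothetical `random_walk` written in plain JAX (no NAVIX involved) plays the role of `run`.

```python
import jax
import jax.numpy as jnp

N_TIMESTEPS = 100


def random_walk(seed):
    """Hypothetical stand-in for `run`: a 1-D random walk driven by lax.scan."""
    key = jax.random.PRNGKey(seed)
    steps = jax.random.choice(key, jnp.array([-1, 1]), (N_TIMESTEPS,))

    def body_fun(position, step):
        return position + step, ()

    return jax.lax.scan(body_fun, jnp.zeros((), jnp.int32), steps)[0]


# One final position per seed: shape (1000,)
final_positions = jax.jit(jax.vmap(random_walk))(jnp.arange(1000))
```

Because the per-seed function is written for a single environment, batching is left entirely to `jax.vmap`, and `jax.jit` compiles the batched computation once.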
Compiling a training loop over many parallel episodes:
```python
import jax
import jax.numpy as jnp
import navix as nx
from jax import random

MAX_STEPS = 100  # fixed horizon: a Python `while not done` loop cannot be jit-compiled


def policy(key, observation):
    """Hypothetical policy: samples a random action (assumes MiniGrid's 7 actions)."""
    # ... your policy logic ...
    return random.randint(key, (), 0, 7)


def run_episode(env, key):
    """Simulates MAX_STEPS steps with `policy` and returns the summed reward"""
    reset_key, policy_key = random.split(key)
    timestep = env.reset(reset_key)

    def body_fun(carry, step_key):
        timestep, total_reward = carry
        action = policy(step_key, timestep.observation)
        timestep = env.step(timestep, action)
        # Rewards are summed over the fixed horizon; episode boundaries are not handled here
        return (timestep, total_reward + timestep.reward), None

    init = (timestep, jnp.zeros_like(timestep.reward))
    (_, total_reward), _ = jax.lax.scan(body_fun, init, random.split(policy_key, MAX_STEPS))
    return total_reward


def train_policy(num_iterations, num_envs):
    """Trains a policy over multiple parallel episodes"""
    env = nx.make('MiniGrid-Empty-8x8-v0')
    # Vectorise over PRNG keys and compile all parallel episodes with XLA
    compiled_train = jax.jit(jax.vmap(lambda key: run_episode(env, key)))
    key = random.PRNGKey(0)
    for _ in range(num_iterations):
        key, subkey = random.split(key)
        rewards = compiled_train(random.split(subkey, num_envs))
        # ... Update the policy based on rewards ...
    return rewards


# Start the training
train_policy(num_iterations=10, num_envs=100)
```
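Inside a compiled loop, an episode cannot stop early through a Python `while not done`: the loop runs for a fixed number of steps. If rewards collected after an episode ends should not count, one option is to carry a `done` flag through `jax.lax.scan` and mask them. The sketch below shows that idea on plain arrays; `rewards` and `dones` are hypothetical per-step values, not NAVIX outputs.

```python
import jax
import jax.numpy as jnp


def masked_return(rewards, dones):
    """Sums rewards up to and including the first step where `dones` is True."""

    def body_fun(carry, step):
        total, finished = carry
        reward, done = step
        total = total + jnp.where(finished, 0.0, reward)  # ignore steps after the end
        finished = jnp.logical_or(finished, done)
        return (total, finished), None

    init = (jnp.zeros((), jnp.float32), jnp.zeros((), bool))
    (total, _), _ = jax.lax.scan(body_fun, init, (rewards, dones))
    return total


rewards = jnp.array([0.0, 0.0, 1.0, 5.0, 5.0])
dones = jnp.array([False, False, True, False, False])
print(masked_return(rewards, dones))  # 1.0
```

Masking keeps every array shape static, which is what lets the loop be compiled once and vectorised with `jax.vmap`.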
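Computing gradients of a world-model loss that steps the environment:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import navix as nx
from jax import grad


class Model(nn.Module):
    """Placeholder world model: predicts the next observation from the current one."""

    @nn.compact
    def __call__(self, x):
        # ... your NN here (a single dense layer as a placeholder)
        x = x.astype(jnp.float32)
        y = nn.Dense(x.size)(x.reshape(-1))
        return y.reshape(x.shape)


model = Model()
env = nx.make('MiniGrid-Empty-8x8-v0')


def loss(params, timestep):
    action = jnp.asarray(0)
    pred_obs = model.apply(params, timestep.observation)  # predicted next observation
    timestep = env.step(timestep, action)  # actual next observation
    return jnp.square(timestep.observation.astype(jnp.float32) - pred_obs).mean()


key = jax.random.PRNGKey(0)
timestep = env.reset(key)
params = model.init(key, timestep.observation)
gradients = grad(loss)(params, timestep)
```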
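`grad(loss)` returns gradients with the same tree structure as `params`, so a plain gradient-descent step is a `jax.tree_util.tree_map` over both trees. In the sketch below, a small hypothetical dictionary stands in for the Flax parameters, and the learning rate is an arbitrary value for illustration; in practice an optimiser library such as Optax would usually handle this step.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary value for illustration


def sgd_update(params, gradients, lr=LEARNING_RATE):
    """One plain gradient-descent step, applied leaf by leaf."""
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, gradients)


# Hypothetical parameters and gradients with matching tree structure
params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}
gradients = {"dense": {"kernel": jnp.full((2, 2), 2.0), "bias": jnp.ones(2)}}
new_params = sgd_update(params, gradients)
```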
NAVIX is part of a wider ecosystem of JAX-powered modules for RL. Check out the following projects:
- Environments:
  - Gymnax: a broad range of RL environments
  - Brax: a physics engine for robotics experiments
  - EnvPool: a set of various batched environments
  - Craftax: a JAX reimplementation of the game Crafter
  - Jumanji: another set of diverse environments
  - PGX: board games commonly used for RL, such as backgammon, chess, shogi, and go
  - JAX-MARL: multi-agent RL environments in JAX
  - Xland-Minigrid: a set of JAX-reimplemented grid-world environments
  - Minimax: a JAX library for RL autocurricula with 120x faster baselines
- Agents:
  - PureJaxRl: fully jitted training routines
  - Rejax: a suite of diverse agents, including DDPG, DQN, PPO, SAC, and TD3
  - Stoix: useful implementations of popular single-agent RL algorithms in JAX
  - JAX-CORL: lean single-file implementations of offline RL algorithms with solid performance reports
  - Dopamine: a research framework for fast prototyping of reinforcement learning algorithms
NAVIX is actively developed. If you'd like to contribute to this open-source project, we welcome your involvement! Start a discussion or open a pull request.
Please consider starring the project if you like NAVIX!
If you use NAVIX, please cite it as:
```bibtex
@misc{pignatelli2023navix,
  author = {Pignatelli, Eduardo},
  title = {Navix: Scaling gridworld navigation with JAX},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/epignatelli/navix}}
}
```