Merge pull request #86 from RPegoud/feat/rainbow
Feat/rainbow
EdanToledo authored Jun 20, 2024
2 parents 0a05daa + f262e79 commit a2b453c
Showing 23 changed files with 524 additions and 88 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,8 @@ __pycache__/
 *.py[cod]
 *$py.class
 
+cuda-*
+
 # C extensions
 *.so
4 changes: 2 additions & 2 deletions Dockerfile
@@ -1,5 +1,5 @@
 # FROM ubuntu:22.04 as base
-FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
+FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
 
 # Ensure no installs try to launch interactive screen
 ARG DEBIAN_FRONTEND=noninteractive
@@ -34,6 +34,6 @@ RUN pip install --quiet --upgrade pip setuptools wheel && \
 # Need to use specific cuda versions for jax
 ARG USE_CUDA=true
 RUN if [ "$USE_CUDA" = true ] ; \
-    then pip install "jax[cuda11_pip]<=0.4.13" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \
+    then pip install "jax[cuda12]>=0.4.10" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \
    fi
 EXPOSE 6006
2 changes: 1 addition & 1 deletion Makefile
@@ -13,7 +13,7 @@ PWD := $(CURDIR)
 endif
 
 # Set flag for docker run command
-BASE_FLAGS=-it --rm -v ${PWD}:/home/app/stoix -w /home/app/stoix
+BASE_FLAGS=-it --rm --shm-size=1g -v ${PWD}:/home/app/stoix -w /home/app/stoix
 RUN_FLAGS=$(GPUS) $(BASE_FLAGS)
 
 DOCKER_IMAGE_NAME = stoix
12 changes: 12 additions & 0 deletions README.md
@@ -127,6 +127,18 @@ Stoix makes use of Hydra for config management. In order to see our default syst
python stoix/systems/ppo/ff_ppo.py env=gymnax/cartpole
```

Additionally, certain variants such as Dueling DQN are determined by the network architecture rather than a separate system file, while the underlying algorithm stays the same. For example, if you wanted to run Dueling DQN you would simply do:

```bash
python stoix/systems/q_learning/ff_dqn.py network=mlp_dueling_dqn
```

or if you wanted to run Dueling C51, you could do:

```bash
python stoix/systems/q_learning/ff_c51.py network=mlp_dueling_c51
```

## Contributing 🤝

Please read our [contributing docs](docs/CONTRIBUTING.md) for details on how to submit pull requests, our Contributor License Agreement and community guidelines.
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -21,8 +21,8 @@ optax @ git+https://github.com/google-deepmind/optax.git
 pgx
 protobuf==3.20.2
 rlax
-tdqm
 tensorboard_logger
 tensorflow_probability
+tqdm
 wandb
 xminigrid @ git+https://github.com/corl-team/xland-minigrid.git@main
3 changes: 2 additions & 1 deletion stoix/base_types.py
@@ -160,7 +160,8 @@ class ExperimentOutput(NamedTuple, Generic[StoixState]):
 LearnerFn = Callable[[StoixState], ExperimentOutput[StoixState]]
 EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[StoixState]]
 
-ActorApply = Callable[[FrozenDict, Observation], DistributionLike]
+ActorApply = Callable[..., DistributionLike]
+
 ActFn = Callable[[FrozenDict, Observation, chex.PRNGKey], chex.Array]
 CriticApply = Callable[[FrozenDict, Observation], Value]
 DistributionCriticApply = Callable[[FrozenDict, Observation], DistributionLike]
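The looser `Callable[..., DistributionLike]` alias lets the same `ActorApply` type cover both ordinary actor applies and ones that need extra keyword arguments (for example, flax RNG streams for noisy layers). A minimal sketch of the idea, using hypothetical apply functions and a stand-in for Stoix's `DistributionLike` rather than the real networks:

```python
from typing import Any, Callable, Dict, Optional

import chex
import distrax

# Stand-in for Stoix's DistributionLike alias (assumption for this sketch).
DistributionLike = distrax.Distribution
ActorApply = Callable[..., DistributionLike]


def plain_apply(params: Any, observation: chex.Array) -> DistributionLike:
    """Apply fn of an ordinary actor network."""
    return distrax.Categorical(logits=observation)


def noisy_apply(
    params: Any,
    observation: chex.Array,
    rngs: Optional[Dict[str, chex.PRNGKey]] = None,
) -> DistributionLike:
    """Apply fn of a noisy network that also consumes named flax RNG streams."""
    return distrax.Categorical(logits=observation)


# Both signatures satisfy the relaxed alias; the old
# Callable[[FrozenDict, Observation], DistributionLike] only matched the first.
applies: list[ActorApply] = [plain_apply, noisy_apply]
```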
@@ -1,7 +1,7 @@
 defaults:
   - logger: base_logger
   - arch: anakin
-  - system: ff_dqn
-  - network: mlp_dueling_dqn
+  - system: ff_rainbow
+  - network: mlp_noisy_dueling_c51
   - env: gymnax/cartpole
   - _self_
14 changes: 14 additions & 0 deletions stoix/configs/network/cnn_noisy_dueling_c51.yaml
@@ -0,0 +1,14 @@
# ---CNN Noisy Dueling C51 Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.CNNTorso
    channel_sizes: [32, 64, 64]
    kernel_sizes: [64, 32, 16]
    strides: [4, 2, 1]
    use_layer_norm: False
    activation: silu
  action_head:
    _target_: stoix.networks.dueling.NoisyDistributionalDuelingQNetwork
    layer_sizes: [512]
    use_layer_norm: False
    activation: silu
2 changes: 1 addition & 1 deletion stoix/configs/network/mlp_dqn.yaml
@@ -3,7 +3,7 @@ actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
-    use_layer_norm: True
+    use_layer_norm: False
     activation: silu
   action_head:
     _target_: stoix.networks.heads.DiscreteQNetworkHead
12 changes: 12 additions & 0 deletions stoix/configs/network/mlp_dueling_c51.yaml
@@ -0,0 +1,12 @@
# ---MLP Dueling C51 Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  action_head:
    _target_: stoix.networks.dueling.DistributionalDuelingQNetwork
    layer_sizes: [128, 128]
    use_layer_norm: False
    activation: silu
7 changes: 5 additions & 2 deletions stoix/configs/network/mlp_dueling_dqn.yaml
@@ -1,9 +1,12 @@
 # ---MLP Dueling DQN Networks---
 actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
     use_layer_norm: False
     activation: silu
+  action_head:
+    _target_: stoix.networks.dueling.DuelingQNetwork
+    layer_sizes: [128, 128]
+    use_layer_norm: False
+    activation: silu
-# We dont use an action head here, because we are using a dueling network
-# which has a value and advantage head.
9 changes: 9 additions & 0 deletions stoix/configs/network/mlp_noisy_dqn.yaml
@@ -0,0 +1,9 @@
# ---Noisy MLP DQN Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.NoisyMLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: True
    activation: silu
  action_head:
    _target_: stoix.networks.heads.DiscreteQNetworkHead
12 changes: 12 additions & 0 deletions stoix/configs/network/mlp_noisy_dueling_c51.yaml
@@ -0,0 +1,12 @@
# ---Noisy MLP Dueling C51 Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.NoisyMLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  action_head:
    _target_: stoix.networks.dueling.NoisyDistributionalDuelingQNetwork
    layer_sizes: [512]
    use_layer_norm: False
    activation: silu
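These network configs are plain Hydra nodes: `_target_` names the flax module and the remaining keys are constructor kwargs, with the environment- and system-dependent arguments (action dimension, epsilon, distribution support, and so on) presumably supplied at instantiation time. A rough sketch of how such a node could be turned into a module with Hydra; the extra kwarg names come from the `NoisyDistributionalDuelingQNetwork` definition in `stoix/networks/dueling.py` further down, but the values and the builder code here are illustrative, not Stoix's actual setup:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# A node shaped like stoix/configs/network/mlp_noisy_dueling_c51.yaml (action_head).
cfg = OmegaConf.create(
    {
        "_target_": "stoix.networks.dueling.NoisyDistributionalDuelingQNetwork",
        "layer_sizes": [512],
        "use_layer_norm": False,
        "activation": "silu",
    }
)

# Remaining constructor arguments passed at instantiation time (illustrative values).
action_head = instantiate(
    cfg,
    action_dim=4,
    epsilon=0.0,
    num_atoms=51,
    v_min=0.0,
    v_max=500.0,
    sigma_zero=0.25,
)
```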
12 changes: 6 additions & 6 deletions stoix/configs/system/ff_dqn.yaml
@@ -3,17 +3,17 @@
 system_name: ff_dqn # Name of the system.
 
 # --- RL hyperparameters ---
-rollout_length: 1 # Number of environment steps per vectorised environment.
-epochs: 1 # Number of sgd steps per rollout.
+rollout_length: 2 # Number of environment steps per vectorised environment.
+epochs: 16 # Number of sgd steps per rollout.
 warmup_steps: 16 # Number of steps to collect before training.
-total_buffer_size: 500_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
-total_batch_size: 256 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
-q_lr: 1e-4 # the learning rate of the Q network network optimizer
+total_buffer_size: 1_000_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
+total_batch_size: 512 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
+q_lr: 5e-4 # the learning rate of the Q network network optimizer
 tau: 0.005 # smoothing coefficient for target networks
 gamma: 0.99 # discount factor
 max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
 decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
 training_epsilon: 0.1 # epsilon for the epsilon-greedy policy during training
 evaluation_epsilon: 0.00 # epsilon for the epsilon-greedy policy during evaluation
 max_abs_reward : 1000.0 # maximum absolute reward value
-huber_loss_parameter: 1.0 # parameter for the huber loss. If 0, it uses MSE loss.
+huber_loss_parameter: 0.0 # parameter for the huber loss. If 0, it uses MSE loss.
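The `huber_loss_parameter` comment ("If 0, it uses MSE loss") is worth spelling out, since moving the default from 1.0 to 0.0 changes how large TD errors are penalised. A small sketch of that switch, written independently of Stoix's actual loss code:

```python
import jax.numpy as jnp


def td_error_loss(td_error: jnp.ndarray, huber_loss_parameter: float) -> jnp.ndarray:
    """Per-element TD loss as described by the config comment: 0 -> MSE, >0 -> Huber."""
    if huber_loss_parameter == 0.0:
        return 0.5 * jnp.square(td_error)
    delta = huber_loss_parameter
    abs_err = jnp.abs(td_error)
    quadratic = jnp.minimum(abs_err, delta)  # quadratic region, capped at delta
    linear = abs_err - quadratic             # linear tail beyond delta
    return 0.5 * jnp.square(quadratic) + delta * linear


# With the old default (1.0) large errors grow linearly; with 0.0 they grow quadratically.
print(td_error_loss(jnp.array([0.5, 3.0]), 1.0))  # [0.125, 2.5]
print(td_error_loss(jnp.array([0.5, 3.0]), 0.0))  # [0.125, 4.5]
```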
25 changes: 25 additions & 0 deletions stoix/configs/system/ff_rainbow.yaml
@@ -0,0 +1,25 @@
# --- Defaults FF-RAINBOW ---

system_name: ff_rainbow # Name of the system.

# --- RL hyperparameters ---
rollout_length: 4 # Number of environment steps per vectorised environment.
epochs: 128 # Number of sgd steps per rollout.
warmup_steps: 16 # Number of steps to collect before training.
total_buffer_size: 1_000_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
total_batch_size: 512 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
priority_exponent: 0.5 # exponent for the prioritised experience replay
importance_sampling_exponent: 0.4 # exponent for the importance sampling weights
n_step: 5 # how many steps in the transition to use for the n-step return
q_lr: 6.25e-5 # the learning rate of the Q network network optimizer
tau: 0.005 # smoothing coefficient for target networks
gamma: 0.99 # discount factor
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
training_epsilon: 0.0 # epsilon for the epsilon-greedy policy during training
evaluation_epsilon: 0.0 # epsilon for the epsilon-greedy policy during evaluation
max_abs_reward: 1000.0 # maximum absolute reward value
num_atoms: 51 # number of atoms in the distributional Q network
v_min: 0.0 # minimum value of the support
v_max: 500.0 # maximum value of the support
sigma_zero: 0.25 # initialization value for noisy variance terms
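Several of these Rainbow-specific values interact: `num_atoms`, `v_min` and `v_max` define the fixed support of the categorical value distribution, while `n_step` and `gamma` set the effective bootstrap horizon. A quick illustration of those relationships, with the config values copied in and a made-up reward window:

```python
import jax.numpy as jnp

num_atoms, v_min, v_max = 51, 0.0, 500.0
n_step, gamma = 5, 0.99

# Fixed support used by the distributional (C51-style) head: 51 atoms, 10.0 apart.
atoms = jnp.linspace(v_min, v_max, num_atoms)
print(atoms[1] - atoms[0])  # 10.0

# n-step target: discounted sum of the next n rewards plus a bootstrapped value.
rewards = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0])  # illustrative rewards, not real data
bootstrap_value = 12.3                          # illustrative value estimate
discounts = gamma ** jnp.arange(n_step)
n_step_return = jnp.sum(discounts * rewards) + (gamma**n_step) * bootstrap_value
```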
11 changes: 9 additions & 2 deletions stoix/evaluator.py
@@ -22,12 +22,19 @@
 from stoix.utils.jax_utils import unreplicate_batch_dim
 
 
-def get_distribution_act_fn(config: DictConfig, actor_apply: ActorApply) -> ActFn:
+def get_distribution_act_fn(
+    config: DictConfig,
+    actor_apply: ActorApply,
+    rngs: Optional[Dict[str, chex.PRNGKey]] = None,
+) -> ActFn:
     """Get the act_fn for a network that returns a distribution."""
 
     def act_fn(params: FrozenDict, observation: chex.Array, key: chex.PRNGKey) -> chex.Array:
         """Get the action from the distribution."""
-        pi = actor_apply(params, observation)
+        if rngs is None:
+            pi = actor_apply(params, observation)
+        else:
+            pi = actor_apply(params, observation, rngs=rngs)
         if config.arch.evaluation_greedy:
             action = pi.mode()
         else:
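The optional `rngs` argument exists because noisy networks resample their parameter noise inside `apply`, which in flax requires named RNG streams. A self-contained analogue of how such an apply call looks; this is not Stoix's `NoisyLinear`, and the stream name `"noise"` is an assumption for the sketch:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyNoisyLayer(nn.Module):
    """Toy stand-in for a noisy layer: adds sampled noise to a Dense output."""

    features: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        noise = jax.random.normal(self.make_rng("noise"), (self.features,))
        return nn.Dense(self.features)(x) + 0.1 * noise


model = TinyNoisyLayer(features=4)
x = jnp.ones((1, 8))
params = model.init({"params": jax.random.PRNGKey(0), "noise": jax.random.PRNGKey(1)}, x)

# The evaluator's act_fn makes the same kind of branch: pass rngs only when needed.
out = model.apply(params, x, rngs={"noise": jax.random.PRNGKey(2)})
```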
88 changes: 86 additions & 2 deletions stoix/networks/dueling.py
@@ -2,12 +2,14 @@
 
 import chex
 import distrax
+import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
 from flax.linen.initializers import Initializer, orthogonal
 
-from stoix.networks.torso import MLPTorso
+from stoix.networks.layers import NoisyLinear
+from stoix.networks.torso import MLPTorso, NoisyMLPTorso
 
 
 class DuelingQNetwork(nn.Module):
@@ -23,13 +25,18 @@ class DuelingQNetwork(nn.Module):
     def __call__(self, inputs: chex.Array) -> chex.Array:
 
         value = MLPTorso(
-            (*self.layer_sizes, 1), self.activation, self.use_layer_norm, self.kernel_init
+            (*self.layer_sizes, 1),
+            self.activation,
+            self.use_layer_norm,
+            self.kernel_init,
+            activate_final=False,
         )(inputs)
         advantages = MLPTorso(
             (*self.layer_sizes, self.action_dim),
             self.activation,
             self.use_layer_norm,
             self.kernel_init,
+            activate_final=False,
         )(inputs)
 
         # Advantages have zero mean.
@@ -38,3 +45,80 @@ def __call__(self, inputs: chex.Array) -> chex.Array:
        q_values = value + advantages

        return distrax.EpsilonGreedy(preferences=q_values, epsilon=self.epsilon)


class DistributionalDuelingQNetwork(nn.Module):
    num_atoms: int
    v_max: float
    v_min: float
    action_dim: int
    epsilon: float
    layer_sizes: Sequence[int]
    activation: str = "relu"
    use_layer_norm: bool = False
    kernel_init: Initializer = orthogonal(np.sqrt(2.0))

    @nn.compact
    def __call__(self, inputs: chex.Array) -> chex.Array:

        value_torso = MLPTorso(
            self.layer_sizes, self.activation, self.use_layer_norm, self.kernel_init
        )(inputs)
        advantages_torso = MLPTorso(
            self.layer_sizes,
            self.activation,
            self.use_layer_norm,
            self.kernel_init,
        )(inputs)

        value_logits = nn.Dense(self.num_atoms, kernel_init=self.kernel_init)(value_torso)
        value_logits = jnp.reshape(value_logits, (-1, 1, self.num_atoms))
        adv_logits = nn.Dense(self.action_dim * self.num_atoms, kernel_init=self.kernel_init)(
            advantages_torso
        )
        adv_logits = jnp.reshape(adv_logits, (-1, self.action_dim, self.num_atoms))
        q_logits = value_logits + adv_logits - adv_logits.mean(axis=1, keepdims=True)

        atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
        q_dist = jax.nn.softmax(q_logits)
        q_values = jnp.sum(q_dist * atoms, axis=2)
        q_values = jax.lax.stop_gradient(q_values)
        atoms = jnp.broadcast_to(atoms, (q_values.shape[0], self.num_atoms))
        return distrax.EpsilonGreedy(preferences=q_values, epsilon=self.epsilon), q_logits, atoms


class NoisyDistributionalDuelingQNetwork(nn.Module):
    num_atoms: int
    v_max: float
    v_min: float
    action_dim: int
    epsilon: float
    layer_sizes: Sequence[int]
    sigma_zero: float
    activation: str = "relu"
    use_layer_norm: bool = False
    kernel_init: Initializer = orthogonal(np.sqrt(2.0))

    @nn.compact
    def __call__(self, embeddings: chex.Array) -> chex.Array:
        value_torso = NoisyMLPTorso(
            self.layer_sizes, self.activation, self.use_layer_norm, self.sigma_zero
        )(embeddings)
        advantages_torso = NoisyMLPTorso(
            self.layer_sizes, self.activation, self.use_layer_norm, self.sigma_zero
        )(embeddings)

        value_logits = NoisyLinear(self.num_atoms, sigma_zero=self.sigma_zero)(value_torso)
        value_logits = jnp.reshape(value_logits, (-1, 1, self.num_atoms))
        adv_logits = NoisyLinear(self.action_dim * self.num_atoms, sigma_zero=self.sigma_zero)(
            advantages_torso
        )
        adv_logits = jnp.reshape(adv_logits, (-1, self.action_dim, self.num_atoms))
        q_logits = value_logits + adv_logits - adv_logits.mean(axis=1, keepdims=True)

        atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
        q_dist = jax.nn.softmax(q_logits)
        q_values = jnp.sum(q_dist * atoms, axis=2)
        q_values = jax.lax.stop_gradient(q_values)
        atoms = jnp.broadcast_to(atoms, (q_values.shape[0], self.num_atoms))
        return distrax.EpsilonGreedy(preferences=q_values, epsilon=self.epsilon), q_logits, atoms
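Both distributional heads share the same aggregation: a dueling combination in logit space per atom, a softmax over the fixed support, and an expected value per action (with `stop_gradient` so the epsilon-greedy preferences do not backpropagate). A shape-level sketch of that computation with made-up sizes, outside of flax:

```python
import jax
import jax.numpy as jnp

batch, num_actions, num_atoms = 2, 3, 51
v_min, v_max = 0.0, 500.0

key_v, key_a = jax.random.split(jax.random.PRNGKey(0))
value_logits = jax.random.normal(key_v, (batch, 1, num_atoms))          # value stream
adv_logits = jax.random.normal(key_a, (batch, num_actions, num_atoms))  # advantage stream

# Dueling combination per atom: centre the advantages across actions.
q_logits = value_logits + adv_logits - adv_logits.mean(axis=1, keepdims=True)

# Categorical distribution over the fixed support, then expected Q-values.
atoms = jnp.linspace(v_min, v_max, num_atoms)
q_dist = jax.nn.softmax(q_logits, axis=-1)
q_values = jnp.sum(q_dist * atoms, axis=2)  # shape (batch, num_actions)
q_values = jax.lax.stop_gradient(q_values)  # greedy preferences only, no gradient
```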