Feat/rainbow #86

Merged · 27 commits · Jun 20, 2024

Commits (27)
6377ee0
noisy layer v_0
RPegoud May 31, 2024
aa01f11
working noisy layer
RPegoud Jun 2, 2024
ec985d4
Added noisy MLP torso
RPegoud Jun 3, 2024
8e4b1ee
noisy DQN setup
RPegoud Jun 3, 2024
1333093
corrected noise generation
RPegoud Jun 3, 2024
37e82b1
ff_dqn v0
RPegoud Jun 3, 2024
7bee362
fixed docker jax version issue, typo in requirements
RPegoud Jun 5, 2024
b05a5a2
added optional rng key to evaluator, functional version of NoisyDQN (…
RPegoud Jun 5, 2024
53560fe
n-step C51 progress
RPegoud Jun 7, 2024
6728801
trajectory buffer updates
RPegoud Jun 8, 2024
4ed21b4
batch_add for warmup step
RPegoud Jun 12, 2024
d772052
buffer seq add fix
RPegoud Jun 12, 2024
77a7415
refactoring
RPegoud Jun 12, 2024
0d84d2e
n_step_categorical_double_q_learning v0 (requires comments and testing)
RPegoud Jun 13, 2024
fb5bc0d
added comments
RPegoud Jun 14, 2024
4c42494
feat: edit rainbow to not require a new loss fn
EdanToledo Jun 14, 2024
02577b6
fix: the size of the traj buffer and change configs
EdanToledo Jun 15, 2024
afbc2b9
feat: add distributional dueling head
EdanToledo Jun 15, 2024
82b5a01
feat: added rainbow dueling head with CNN torso and Noisy layers, add…
RPegoud Jun 17, 2024
ed78f8e
Merge branch 'main' into feat/rainbow
EdanToledo Jun 17, 2024
defc593
chore: clean up code and modify dueling configs
EdanToledo Jun 17, 2024
ed1111d
feat: importance sampling exponent scheduler
RPegoud Jun 18, 2024
0b7462a
corrected the scheduler max steps parameter
RPegoud Jun 18, 2024
6f4a190
chore: small modifications
EdanToledo Jun 18, 2024
7d6890b
chore: type check attempt
RPegoud Jun 19, 2024
5b70d44
fix: make typing simpler
EdanToledo Jun 19, 2024
f262e79
chore: refactor noisy layer
EdanToledo Jun 20, 2024
Files changed
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,8 @@ __pycache__/
*.py[cod]
*$py.class

cuda-*

# C extensions
*.so

4 changes: 2 additions & 2 deletions Dockerfile
@@ -1,5 +1,5 @@
# FROM ubuntu:22.04 as base
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04

# Ensure no installs try to launch interactive screen
ARG DEBIAN_FRONTEND=noninteractive
@@ -34,6 +34,6 @@ RUN pip install --quiet --upgrade pip setuptools wheel && \
# Need to use specific cuda versions for jax
ARG USE_CUDA=true
RUN if [ "$USE_CUDA" = true ] ; \
then pip install "jax[cuda11_pip]<=0.4.13" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \
then pip install "jax[cuda12]>=0.4.10" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \
fi
EXPOSE 6006
2 changes: 1 addition & 1 deletion Makefile
@@ -13,7 +13,7 @@ PWD := $(CURDIR)
endif

# Set flag for docker run command
BASE_FLAGS=-it --rm -v ${PWD}:/home/app/stoix -w /home/app/stoix
BASE_FLAGS=-it --rm --shm-size=1g -v ${PWD}:/home/app/stoix -w /home/app/stoix
RUN_FLAGS=$(GPUS) $(BASE_FLAGS)

DOCKER_IMAGE_NAME = stoix
12 changes: 12 additions & 0 deletions README.md
@@ -127,6 +127,18 @@ Stoix makes use of Hydra for config management. In order to see our default syst
python stoix/systems/ppo/ff_ppo.py env=gymnax/cartpole
```

Additionally, certain variants such as Dueling DQN are determined by the network architecture alone, while the underlying algorithm stays the same. For example, to run Dueling DQN you would simply do:

```bash
python stoix/systems/q_learning/ff_dqn.py network=mlp_dueling_dqn
```

or, if you wanted to run Dueling C51, you could do:

```bash
python stoix/systems/q_learning/ff_c51.py network=mlp_dueling_c51
```

## Contributing 🤝

Please read our [contributing docs](docs/CONTRIBUTING.md) for details on how to submit pull requests, our Contributor License Agreement and community guidelines.
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -21,8 +21,8 @@ optax @ git+https://github.com/google-deepmind/optax.git
pgx
protobuf==3.20.2
rlax
tdqm
tensorboard_logger
tensorflow_probability
tqdm
wandb
xminigrid @ git+https://github.com/corl-team/xland-minigrid.git@main
3 changes: 2 additions & 1 deletion stoix/base_types.py
@@ -160,7 +160,8 @@ class ExperimentOutput(NamedTuple, Generic[StoixState]):
LearnerFn = Callable[[StoixState], ExperimentOutput[StoixState]]
EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[StoixState]]

ActorApply = Callable[[FrozenDict, Observation], DistributionLike]
ActorApply = Callable[..., DistributionLike]

ActFn = Callable[[FrozenDict, Observation, chex.PRNGKey], chex.Array]
CriticApply = Callable[[FrozenDict, Observation], Value]
DistributionCriticApply = Callable[[FrozenDict, Observation], DistributionLike]
@@ -1,7 +1,7 @@
defaults:
- logger: base_logger
- arch: anakin
- system: ff_dqn
- network: mlp_dueling_dqn
- system: ff_rainbow
- network: mlp_noisy_dueling_c51
- env: gymnax/cartpole
- _self_
14 changes: 14 additions & 0 deletions stoix/configs/network/cnn_noisy_dueling_c51.yaml
@@ -0,0 +1,14 @@
# ---CNN Noisy Dueling C51 Networks---
actor_network:
pre_torso:
_target_: stoix.networks.torso.CNNTorso
channel_sizes: [32, 64, 64]
kernel_sizes: [64, 32, 16]
strides: [4, 2, 1]
use_layer_norm: False
activation: silu
action_head:
_target_: stoix.networks.dueling.NoisyDistributionalDuelingQNetwork
layer_sizes: [512]
use_layer_norm: False
activation: silu
2 changes: 1 addition & 1 deletion stoix/configs/network/mlp_dqn.yaml
@@ -3,7 +3,7 @@ actor_network:
pre_torso:
_target_: stoix.networks.torso.MLPTorso
layer_sizes: [256, 256]
use_layer_norm: True
use_layer_norm: False
activation: silu
action_head:
_target_: stoix.networks.heads.DiscreteQNetworkHead
12 changes: 12 additions & 0 deletions stoix/configs/network/mlp_dueling_c51.yaml
@@ -0,0 +1,12 @@
# ---MLP Dueling C51 Networks---
actor_network:
pre_torso:
_target_: stoix.networks.torso.MLPTorso
layer_sizes: [256, 256]
use_layer_norm: False
activation: silu
action_head:
_target_: stoix.networks.dueling.DistributionalDuelingQNetwork
layer_sizes: [128, 128]
use_layer_norm: False
activation: silu
7 changes: 5 additions & 2 deletions stoix/configs/network/mlp_dueling_dqn.yaml
@@ -1,9 +1,12 @@
# ---MLP Dueling DQN Networks---
actor_network:
pre_torso:
_target_: stoix.networks.torso.MLPTorso
layer_sizes: [256, 256]
use_layer_norm: False
activation: silu
action_head:
_target_: stoix.networks.dueling.DuelingQNetwork
layer_sizes: [128, 128]
use_layer_norm: False
activation: silu
# We don't use an action head here, because we are using a dueling network
# which has a value and advantage head.
9 changes: 9 additions & 0 deletions stoix/configs/network/mlp_noisy_dqn.yaml
@@ -0,0 +1,9 @@
# ---Noisy MLP DQN Networks---
actor_network:
pre_torso:
_target_: stoix.networks.torso.NoisyMLPTorso
layer_sizes: [256, 256]
use_layer_norm: True
activation: silu
action_head:
_target_: stoix.networks.heads.DiscreteQNetworkHead
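
These noisy configs rely on `NoisyMLPTorso` and `NoisyLinear` (from `stoix/networks/torso.py` and `stoix/networks/layers.py`). As rough orientation, below is a minimal sketch of a factorised NoisyNet linear layer (Fortunato et al., 2018); the initialisation scheme, the `"noise"` RNG stream name, and every other detail here are assumptions, not the actual Stoix implementation.

```python
# Illustrative factorised NoisyNet linear layer; NOT the Stoix implementation.
import chex
import jax
import jax.numpy as jnp
from flax import linen as nn


def scaled_noise(key: chex.PRNGKey, shape) -> chex.Array:
    """Factorised-noise transform f(x) = sign(x) * sqrt(|x|)."""
    x = jax.random.normal(key, shape)
    return jnp.sign(x) * jnp.sqrt(jnp.abs(x))


class FactorisedNoisyLinear(nn.Module):
    """Sketch of a noisy linear layer with learnable noise scales."""

    features: int
    sigma_zero: float = 0.5  # plays the same role as sigma_zero in ff_rainbow.yaml

    @nn.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        in_features = x.shape[-1]
        bound = 1.0 / jnp.sqrt(in_features)
        sigma_init_value = self.sigma_zero / jnp.sqrt(in_features)

        def mu_init(key: chex.PRNGKey, shape) -> chex.Array:
            return jax.random.uniform(key, shape, minval=-bound, maxval=bound)

        def sigma_init(key: chex.PRNGKey, shape) -> chex.Array:
            return jnp.full(shape, sigma_init_value)

        mu_w = self.param("mu_w", mu_init, (in_features, self.features))
        sigma_w = self.param("sigma_w", sigma_init, (in_features, self.features))
        mu_b = self.param("mu_b", mu_init, (self.features,))
        sigma_b = self.param("sigma_b", sigma_init, (self.features,))

        # Factorised Gaussian noise: one noise vector per input, one per output.
        key_in, key_out = jax.random.split(self.make_rng("noise"))
        eps_in = scaled_noise(key_in, (in_features, 1))
        eps_out = scaled_noise(key_out, (1, self.features))

        weight = mu_w + sigma_w * (eps_in * eps_out)
        bias = mu_b + sigma_b * jnp.squeeze(eps_out, axis=0)
        return x @ weight + bias
```

Because such a layer draws fresh noise via `make_rng` at apply time, callers need to pass an extra RNG stream, which is what the optional `rngs` argument added to `get_distribution_act_fn` in `stoix/evaluator.py` further down accommodates.
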
12 changes: 12 additions & 0 deletions stoix/configs/network/mlp_noisy_dueling_c51.yaml
@@ -0,0 +1,12 @@
# ---MLP Noisy Dueling C51 Networks---
actor_network:
pre_torso:
_target_: stoix.networks.torso.NoisyMLPTorso
layer_sizes: [256, 256]
use_layer_norm: False
activation: silu
action_head:
_target_: stoix.networks.dueling.NoisyDistributionalDuelingQNetwork
layer_sizes: [512]
use_layer_norm: False
activation: silu
12 changes: 6 additions & 6 deletions stoix/configs/system/ff_dqn.yaml
@@ -3,17 +3,17 @@
system_name: ff_dqn # Name of the system.

# --- RL hyperparameters ---
rollout_length: 1 # Number of environment steps per vectorised environment.
epochs: 1 # Number of sgd steps per rollout.
rollout_length: 2 # Number of environment steps per vectorised environment.
epochs: 16 # Number of sgd steps per rollout.
warmup_steps: 16 # Number of steps to collect before training.
total_buffer_size: 500_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
total_batch_size: 256 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
q_lr: 1e-4 # the learning rate of the Q network optimizer
total_buffer_size: 1_000_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
total_batch_size: 512 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
q_lr: 5e-4 # the learning rate of the Q network optimizer
tau: 0.005 # smoothing coefficient for target networks
gamma: 0.99 # discount factor
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
training_epsilon: 0.1 # epsilon for the epsilon-greedy policy during training
evaluation_epsilon: 0.00 # epsilon for the epsilon-greedy policy during evaluation
max_abs_reward : 1000.0 # maximum absolute reward value
huber_loss_parameter: 1.0 # parameter for the huber loss. If 0, it uses MSE loss.
huber_loss_parameter: 0.0 # parameter for the huber loss. If 0, it uses MSE loss.
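
The `huber_loss_parameter` comment states that a value of 0 falls back to an MSE-style loss. As a hedged illustration (using `rlax`, which is already in the requirements), the dispatch might look like the sketch below; whether the Stoix loss code branches exactly this way is an assumption.

```python
# Hedged sketch of the huber_loss_parameter semantics described above:
# a positive value is used as the Huber delta, 0 falls back to an L2/MSE-style
# TD loss. The exact branching in Stoix is an assumption.
import jax.numpy as jnp
import rlax


def td_loss(td_error: jnp.ndarray, huber_loss_parameter: float) -> jnp.ndarray:
    if huber_loss_parameter > 0.0:
        return rlax.huber_loss(td_error, huber_loss_parameter)
    return rlax.l2_loss(td_error)


td_error = jnp.array([-2.0, 0.3, 1.5])
print(td_loss(td_error, 1.0))  # Huber with delta=1.0 (the pre-PR ff_dqn default)
print(td_loss(td_error, 0.0))  # 0.5 * td_error**2
```
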
25 changes: 25 additions & 0 deletions stoix/configs/system/ff_rainbow.yaml
@@ -0,0 +1,25 @@
# --- Defaults FF-RAINBOW ---

system_name: ff_rainbow # Name of the system.

# --- RL hyperparameters ---
rollout_length: 4 # Number of environment steps per vectorised environment.
epochs: 128 # Number of sgd steps per rollout.
warmup_steps: 16 # Number of steps to collect before training.
total_buffer_size: 1_000_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
total_batch_size: 512 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
priority_exponent: 0.5 # exponent for the prioritised experience replay
importance_sampling_exponent: 0.4 # exponent for the importance sampling weights
n_step: 5 # how many steps in the transition to use for the n-step return
q_lr: 6.25e-5 # the learning rate of the Q network optimizer
tau: 0.005 # smoothing coefficient for target networks
gamma: 0.99 # discount factor
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
training_epsilon: 0.0 # epsilon for the epsilon-greedy policy during training
evaluation_epsilon: 0.0 # epsilon for the epsilon-greedy policy during evaluation
max_abs_reward: 1000.0 # maximum absolute reward value
num_atoms: 51 # number of atoms in the distributional Q network
v_min: 0.0 # minimum value of the support
v_max: 500.0 # maximum value of the support
sigma_zero: 0.25 # initialization value for noisy variance terms
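
The commit history mentions an importance sampling exponent scheduler, and `ff_rainbow.yaml` sets the initial exponent to 0.4. In the Rainbow paper this exponent is annealed to 1.0 over training; a minimal sketch of such a schedule is shown below. The use of `optax.linear_schedule` and the step count are assumptions, not necessarily the PR's actual implementation.

```python
# Hedged sketch of an importance-sampling-exponent schedule: anneal the
# exponent from its configured initial value towards 1.0 over training.
# optax.linear_schedule and the step count are assumptions.
import optax

importance_sampling_exponent = 0.4  # from ff_rainbow.yaml
total_learning_steps = 100_000      # illustrative only

beta_schedule = optax.linear_schedule(
    init_value=importance_sampling_exponent,
    end_value=1.0,
    transition_steps=total_learning_steps,
)

print(beta_schedule(0))                     # 0.4 at the start of training
print(beta_schedule(total_learning_steps))  # 1.0 once fully annealed
```
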
11 changes: 9 additions & 2 deletions stoix/evaluator.py
@@ -22,12 +22,19 @@
from stoix.utils.jax_utils import unreplicate_batch_dim


def get_distribution_act_fn(config: DictConfig, actor_apply: ActorApply) -> ActFn:
def get_distribution_act_fn(
config: DictConfig,
actor_apply: ActorApply,
rngs: Optional[Dict[str, chex.PRNGKey]] = None,
) -> ActFn:
"""Get the act_fn for a network that returns a distribution."""

def act_fn(params: FrozenDict, observation: chex.Array, key: chex.PRNGKey) -> chex.Array:
"""Get the action from the distribution."""
pi = actor_apply(params, observation)
if rngs is None:
pi = actor_apply(params, observation)
else:
pi = actor_apply(params, observation, rngs=rngs)
if config.arch.evaluation_greedy:
action = pi.mode()
else:
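
The evaluator change above forwards an optional `rngs` dictionary to the actor apply function. The self-contained toy example below illustrates why this is needed for networks containing noisy layers; `ToyNoisyHead` and the `"noise"` stream name are illustrative assumptions, not Stoix code.

```python
# Toy illustration of why the optional `rngs` argument matters: Flax modules
# that call make_rng (as noisy layers do) need an extra RNG stream at apply
# time. ToyNoisyHead and the "noise" stream name are assumptions.
import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyNoisyHead(nn.Module):
    features: int = 4

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        noise = jax.random.normal(self.make_rng("noise"), (self.features,))
        return nn.Dense(self.features)(x) + noise


model = ToyNoisyHead()
observation = jnp.zeros((1, 8))
params_key, noise_key = jax.random.split(jax.random.PRNGKey(0))
params = model.init({"params": params_key, "noise": noise_key}, observation)

# Mirrors the `actor_apply(params, observation, rngs=rngs)` branch above;
# omitting `rngs` would raise a missing-RNG error inside make_rng.
out = model.apply(params, observation, rngs={"noise": noise_key})
```
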
88 changes: 86 additions & 2 deletions stoix/networks/dueling.py
@@ -2,12 +2,14 @@

import chex
import distrax
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen.initializers import Initializer, orthogonal

from stoix.networks.torso import MLPTorso
from stoix.networks.layers import NoisyLinear
from stoix.networks.torso import MLPTorso, NoisyMLPTorso


class DuelingQNetwork(nn.Module):
@@ -23,13 +25,18 @@ class DuelingQNetwork(nn.Module):
def __call__(self, inputs: chex.Array) -> chex.Array:

value = MLPTorso(
(*self.layer_sizes, 1), self.activation, self.use_layer_norm, self.kernel_init
(*self.layer_sizes, 1),
self.activation,
self.use_layer_norm,
self.kernel_init,
activate_final=False,
)(inputs)
advantages = MLPTorso(
(*self.layer_sizes, self.action_dim),
self.activation,
self.use_layer_norm,
self.kernel_init,
activate_final=False,
)(inputs)

# Advantages have zero mean.
@@ -38,3 +45,80 @@ def __call__(self, inputs: chex.Array) -> chex.Array:
q_values = value + advantages

return distrax.EpsilonGreedy(preferences=q_values, epsilon=self.epsilon)


class DistributionalDuelingQNetwork(nn.Module):
num_atoms: int
v_max: float
v_min: float
action_dim: int
epsilon: float
layer_sizes: Sequence[int]
activation: str = "relu"
use_layer_norm: bool = False
kernel_init: Initializer = orthogonal(np.sqrt(2.0))

@nn.compact
def __call__(self, inputs: chex.Array) -> chex.Array:

value_torso = MLPTorso(
self.layer_sizes, self.activation, self.use_layer_norm, self.kernel_init
)(inputs)
advantages_torso = MLPTorso(
self.layer_sizes,
self.activation,
self.use_layer_norm,
self.kernel_init,
)(inputs)

value_logits = nn.Dense(self.num_atoms, kernel_init=self.kernel_init)(value_torso)
value_logits = jnp.reshape(value_logits, (-1, 1, self.num_atoms))
adv_logits = nn.Dense(self.action_dim * self.num_atoms, kernel_init=self.kernel_init)(
advantages_torso
)
adv_logits = jnp.reshape(adv_logits, (-1, self.action_dim, self.num_atoms))
q_logits = value_logits + adv_logits - adv_logits.mean(axis=1, keepdims=True)

atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
q_dist = jax.nn.softmax(q_logits)
q_values = jnp.sum(q_dist * atoms, axis=2)
q_values = jax.lax.stop_gradient(q_values)
atoms = jnp.broadcast_to(atoms, (q_values.shape[0], self.num_atoms))
return distrax.EpsilonGreedy(preferences=q_values, epsilon=self.epsilon), q_logits, atoms


class NoisyDistributionalDuelingQNetwork(nn.Module):
num_atoms: int
v_max: float
v_min: float
action_dim: int
epsilon: float
layer_sizes: Sequence[int]
sigma_zero: float
activation: str = "relu"
use_layer_norm: bool = False
kernel_init: Initializer = orthogonal(np.sqrt(2.0))

@nn.compact
def __call__(self, embeddings: chex.Array) -> chex.Array:
value_torso = NoisyMLPTorso(
self.layer_sizes, self.activation, self.use_layer_norm, self.sigma_zero
)(embeddings)
advantages_torso = NoisyMLPTorso(
self.layer_sizes, self.activation, self.use_layer_norm, self.sigma_zero
)(embeddings)

value_logits = NoisyLinear(self.num_atoms, sigma_zero=self.sigma_zero)(value_torso)
value_logits = jnp.reshape(value_logits, (-1, 1, self.num_atoms))
adv_logits = NoisyLinear(self.action_dim * self.num_atoms, sigma_zero=self.sigma_zero)(
advantages_torso
)
adv_logits = jnp.reshape(adv_logits, (-1, self.action_dim, self.num_atoms))
q_logits = value_logits + adv_logits - adv_logits.mean(axis=1, keepdims=True)

atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
q_dist = jax.nn.softmax(q_logits)
q_values = jnp.sum(q_dist * atoms, axis=2)
q_values = jax.lax.stop_gradient(q_values)
atoms = jnp.broadcast_to(atoms, (q_values.shape[0], self.num_atoms))
return distrax.EpsilonGreedy(preferences=q_values, epsilon=self.epsilon), q_logits, atoms
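
Both distributional heads above convert per-action categorical logits over a fixed support of atoms into expected Q-values before wrapping them in an epsilon-greedy distribution. A minimal, self-contained sketch of that readout, with illustrative shapes and support values rather than the PR's defaults, is:

```python
# Minimal sketch of the C51 readout used by the distributional dueling heads:
# per-action categorical logits over a fixed atom support are turned into
# expected Q-values, Q(s, a) = sum_i p_i(s, a) * z_i. Shapes and support
# values here are illustrative only.
import jax
import jax.numpy as jnp

batch, action_dim, num_atoms = 1, 3, 5
v_min, v_max = 0.0, 4.0

q_logits = jax.random.normal(jax.random.PRNGKey(0), (batch, action_dim, num_atoms))

atoms = jnp.linspace(v_min, v_max, num_atoms)   # support z_i
q_dist = jax.nn.softmax(q_logits, axis=-1)      # p_i(s, a)
q_values = jnp.sum(q_dist * atoms, axis=-1)     # expected value per action

print(q_values.shape)  # (1, 3): one Q-value per action, used for action selection
```
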