Skip to content

Commit

Permalink
feat: add MLP torso after CNN torso
Browse files Browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Sep 1, 2024
1 parent 5284124 commit 65b049b
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 10 deletions.
2 changes: 2 additions & 0 deletions stoix/configs/network/cnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ actor_network:
use_layer_norm: False
activation: silu
channel_first: True
hidden_sizes: [128, 128]
action_head:
_target_: stoix.networks.heads.CategoricalHead

Expand All @@ -20,5 +21,6 @@ critic_network:
use_layer_norm: False
activation: silu
channel_first: True
hidden_sizes: [128, 128]
critic_head:
_target_: stoix.networks.heads.ScalarCriticHead
2 changes: 2 additions & 0 deletions stoix/configs/network/cnn_noisy_dueling_c51.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ actor_network:
strides: [4, 2, 1]
use_layer_norm: False
activation: silu
channel_first: True
hidden_sizes: [128, 128]
action_head:
_target_: stoix.networks.dueling.NoisyDistributionalDuelingQNetwork
layer_sizes: [512]
Expand Down
2 changes: 2 additions & 0 deletions stoix/configs/network/visual_resnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ actor_network:
use_layer_norm: False
activation: silu
channel_first: True
hidden_sizes: [128, 128]
action_head:
_target_: stoix.networks.heads.CategoricalHead

Expand All @@ -18,5 +19,6 @@ critic_network:
use_layer_norm: False
activation: silu
channel_first: True
hidden_sizes: [128, 128]
critic_head:
_target_: stoix.networks.heads.ScalarCriticHead
37 changes: 29 additions & 8 deletions stoix/networks/torso.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ def __call__(self, observation: chex.Array) -> chex.Array:
"""Forward pass."""
x = observation
for layer_size in self.layer_sizes:
x = nn.Dense(layer_size, kernel_init=self.kernel_init)(x)
x = nn.Dense(
layer_size, kernel_init=self.kernel_init, use_bias=not self.use_layer_norm
)(x)
if self.use_layer_norm:
x = nn.LayerNorm(use_scale=False)(x)
x = nn.LayerNorm()(x)
if self.activate_final or layer_size != self.layer_sizes[-1]:
x = parse_activation_fn(self.activation)(x)
return x
Expand All @@ -45,17 +47,20 @@ class NoisyMLPTorso(nn.Module):
def __call__(self, observation: chex.Array) -> chex.Array:
x = observation
for layer_size in self.layer_sizes:
x = NoisyLinear(layer_size, sigma_zero=self.sigma_zero)(x)
x = NoisyLinear(
layer_size, sigma_zero=self.sigma_zero, use_bias=not self.use_layer_norm
)(x)
if self.use_layer_norm:
x = nn.LayerNorm(use_scale=False)(x)
x = nn.LayerNorm()(x)
if self.activate_final or layer_size != self.layer_sizes[-1]:
x = parse_activation_fn(self.activation)(x)
return x


class CNNTorso(nn.Module):
"""2D CNN torso. Expects input of shape (batch, height, width, channels).
After this torso, the output is flattened."""
After this torso, the output is flattened and put through an MLP of
hidden_sizes."""

channel_sizes: Sequence[int]
kernel_sizes: Sequence[int]
Expand All @@ -64,6 +69,7 @@ class CNNTorso(nn.Module):
use_layer_norm: bool = False
kernel_init: Initializer = orthogonal(np.sqrt(2.0))
channel_first: bool = False
hidden_sizes: Sequence[int] = (256,)

@nn.compact
def __call__(self, observation: chex.Array) -> chex.Array:
Expand All @@ -72,10 +78,25 @@ def __call__(self, observation: chex.Array) -> chex.Array:
# Move channels to the last dimension if they are first
if self.channel_first:
x = x.transpose((0, 2, 3, 1))
# Convolutional layers
for channel, kernel, stride in zip(self.channel_sizes, self.kernel_sizes, self.strides):
x = nn.Conv(channel, (kernel, kernel), (stride, stride))(x)
x = nn.Conv(
channel, (kernel, kernel), (stride, stride), use_bias=not self.use_layer_norm
)(x)
if self.use_layer_norm:
x = nn.LayerNorm(use_scale=False)(x)
x = nn.LayerNorm(reduction_axes=(-3, -2, -1))(x)
x = parse_activation_fn(self.activation)(x)

return x.reshape(*observation.shape[:-3], -1)
# Flatten
x = x.reshape(*observation.shape[:-3], -1)

# MLP layers
x = MLPTorso(
layer_sizes=self.hidden_sizes,
activation=self.activation,
use_layer_norm=self.use_layer_norm,
kernel_init=self.kernel_init,
activate_final=True,
)(x)

return x
15 changes: 13 additions & 2 deletions stoix/utils/env_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@
import threading
from typing import Any

import envpool
from colorama import Fore, Style

# Envpool is not usable on certain platforms, so we need to handle the ImportError
try:
import envpool
except ImportError:
envpool = None
print(
f"{Fore.MAGENTA}{Style.BRIGHT}Envpool not installed. "
f"Please install it to use the Envpool factory{Style.RESET_ALL}"
)

import gymnasium

from stoix.wrappers.envpool import EnvPoolToJumanji
Expand Down Expand Up @@ -43,7 +54,7 @@ def __call__(self, num_envs: int) -> Any:
num_envs=num_envs,
seed=seed,
gym_reset_return_info=True,
**self.kwargs
**self.kwargs,
)
)

Expand Down

0 comments on commit 65b049b

Please sign in to comment.