Merge pull request #94 from EdanToledo/chore/edit_dist_network_configs
chore: move input of distributional network args into config
EdanToledo authored Jun 20, 2024
2 parents 06c4003 + 16c18fb commit bcd0b1e
Showing 18 changed files with 38 additions and 43 deletions.
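The change replaces keyword arguments that were passed explicitly in each system's learner_setup (v_min, v_max, num_atoms, sigma_zero, num_quantiles) with ${system.*} interpolations in the network configs, which Hydra resolves when the head is instantiated. A minimal, self-contained sketch of that pattern follows; the _target_ and config keys come from the diff below, while the hard-coded values, action_dim, and the standalone OmegaConf setup are illustrative rather than Stoix's actual config composition:

import hydra
from omegaconf import OmegaConf

# Illustrative stand-in for the composed Hydra config; the real one is assembled from
# stoix/configs/**, and instantiation below assumes Stoix is importable so the
# _target_ class can be resolved.
cfg = OmegaConf.create(
    {
        "system": {"vmin": 0.0, "vmax": 500.0, "num_atoms": 51, "training_epsilon": 0.1},
        "network": {
            "actor_network": {
                "action_head": {
                    "_target_": "stoix.networks.heads.DistributionalDiscreteQNetwork",
                    "vmin": "${system.vmin}",
                    "vmax": "${system.vmax}",
                    "num_atoms": "${system.num_atoms}",
                }
            }
        },
    }
)

# Previously vmin/vmax/num_atoms were passed here by hand; after this commit only the
# values absent from the config (action_dim, epsilon) are supplied at call time.
q_network_action_head = hydra.utils.instantiate(
    cfg.network.actor_network.action_head,
    action_dim=4,  # illustrative value
    epsilon=cfg.system.training_epsilon,
)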
4 changes: 4 additions & 0 deletions stoix/configs/network/cnn_noisy_dueling_c51.yaml
@@ -12,3 +12,7 @@ actor_network:
layer_sizes: [512]
use_layer_norm: False
activation: silu
+vmin: ${system.vmin}
+vmax: ${system.vmax}
+num_atoms: ${system.num_atoms}
+sigma_zero: ${system.sigma_zero}
3 changes: 3 additions & 0 deletions stoix/configs/network/mlp_c51.yaml
@@ -7,3 +7,6 @@ actor_network:
activation: silu
action_head:
_target_: stoix.networks.heads.DistributionalDiscreteQNetwork
+vmin: ${system.vmin}
+vmax: ${system.vmax}
+num_atoms: ${system.num_atoms}
3 changes: 3 additions & 0 deletions stoix/configs/network/mlp_d4pg.yaml
@@ -20,3 +20,6 @@ q_network:
activation: silu
critic_head:
_target_: stoix.networks.heads.DistributionalContinuousQNetwork
+vmin: ${system.vmin}
+vmax: ${system.vmax}
+num_atoms: ${system.num_atoms}
3 changes: 3 additions & 0 deletions stoix/configs/network/mlp_dueling_c51.yaml
@@ -10,3 +10,6 @@ actor_network:
layer_sizes: [128, 128]
use_layer_norm: False
activation: silu
+vmin: ${system.vmin}
+vmax: ${system.vmax}
+num_atoms: ${system.num_atoms}
1 change: 1 addition & 0 deletions stoix/configs/network/mlp_noisy_dqn.yaml
@@ -5,5 +5,6 @@ actor_network:
layer_sizes: [256, 256]
use_layer_norm: True
activation: silu
+sigma_zero: ${system.sigma_zero}
action_head:
_target_: stoix.networks.heads.DiscreteQNetworkHead
5 changes: 5 additions & 0 deletions stoix/configs/network/mlp_noisy_dueling_c51.yaml
@@ -5,8 +5,13 @@ actor_network:
layer_sizes: [256, 256]
use_layer_norm: False
activation: silu
+sigma_zero: ${system.sigma_zero}
action_head:
_target_: stoix.networks.dueling.NoisyDistributionalDuelingQNetwork
layer_sizes: [512]
use_layer_norm: False
activation: silu
+vmin: ${system.vmin}
+vmax: ${system.vmax}
+num_atoms: ${system.num_atoms}
+sigma_zero: ${system.sigma_zero}
1 change: 1 addition & 0 deletions stoix/configs/network/mlp_qr_dqn.yaml
@@ -7,3 +7,4 @@ actor_network:
activation: silu
action_head:
_target_: stoix.networks.heads.QuantileDiscreteQNetwork
+num_quantiles: ${system.num_quantiles}
4 changes: 2 additions & 2 deletions stoix/configs/system/ff_c51.yaml
@@ -17,5 +17,5 @@ training_epsilon: 0.1 # epsilon for the epsilon-greedy policy during training
evaluation_epsilon: 0.00 # epsilon for the epsilon-greedy policy during evaluation
max_abs_reward : 1000.0 # maximum absolute reward value
num_atoms: 51 # number of atoms in the distributional Q network
-v_min: 0.0 # minimum value of the support
-v_max: 500.0 # maximum value of the support
+vmin: 0.0 # minimum value of the support
+vmax: 500.0 # maximum value of the support
4 changes: 2 additions & 2 deletions stoix/configs/system/ff_d4pg.yaml
@@ -16,7 +16,7 @@ max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
max_abs_reward : 20_000 # maximum absolute reward value
num_atoms: 301 # number of atoms in the distributional Q network
-v_min: -9_000.0 # minimum value of the support
-v_max: 9_000.0 # maximum value of the support
+vmin: -9_000.0 # minimum value of the support
+vmax: 9_000.0 # maximum value of the support
exploration_noise : 0.15 # standard deviation of the exploration noise
n_step: 5 # how many steps in the trajectory to use for the n-step return
4 changes: 2 additions & 2 deletions stoix/configs/system/ff_rainbow.yaml
@@ -20,6 +20,6 @@ training_epsilon: 0.0 # epsilon for the epsilon-greedy policy during training
evaluation_epsilon: 0.0 # epsilon for the epsilon-greedy policy during evaluation
max_abs_reward: 1000.0 # maximum absolute reward value
num_atoms: 51 # number of atoms in the distributional Q network
-v_min: 0.0 # minimum value of the support
-v_max: 500.0 # maximum value of the support
+vmin: 0.0 # minimum value of the support
+vmax: 500.0 # maximum value of the support
sigma_zero: 0.25 # initialization value for noisy variance terms
12 changes: 6 additions & 6 deletions stoix/networks/dueling.py
@@ -49,8 +49,8 @@ def __call__(self, inputs: chex.Array) -> chex.Array:

class DistributionalDuelingQNetwork(nn.Module):
num_atoms: int
-v_max: float
-v_min: float
+vmax: float
+vmin: float
action_dim: int
epsilon: float
layer_sizes: Sequence[int]
@@ -79,7 +79,7 @@ def __call__(self, inputs: chex.Array) -> chex.Array:
adv_logits = jnp.reshape(adv_logits, (-1, self.action_dim, self.num_atoms))
q_logits = value_logits + adv_logits - adv_logits.mean(axis=1, keepdims=True)

-atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
+atoms = jnp.linspace(self.vmin, self.vmax, self.num_atoms)
q_dist = jax.nn.softmax(q_logits)
q_values = jnp.sum(q_dist * atoms, axis=2)
q_values = jax.lax.stop_gradient(q_values)
@@ -89,8 +89,8 @@ def __call__(self, inputs: chex.Array) -> chex.Array:

class NoisyDistributionalDuelingQNetwork(nn.Module):
num_atoms: int
-v_max: float
-v_min: float
+vmax: float
+vmin: float
action_dim: int
epsilon: float
layer_sizes: Sequence[int]
@@ -116,7 +116,7 @@ def __call__(self, embeddings: chex.Array) -> chex.Array:
adv_logits = jnp.reshape(adv_logits, (-1, self.action_dim, self.num_atoms))
q_logits = value_logits + adv_logits - adv_logits.mean(axis=1, keepdims=True)

-atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
+atoms = jnp.linspace(self.vmin, self.vmax, self.num_atoms)
q_dist = jax.nn.softmax(q_logits)
q_values = jnp.sum(q_dist * atoms, axis=2)
q_values = jax.lax.stop_gradient(q_values)
12 changes: 6 additions & 6 deletions stoix/networks/heads.py
@@ -230,15 +230,15 @@ class DistributionalDiscreteQNetwork(nn.Module):
action_dim: int
epsilon: float
num_atoms: int
-v_min: float
-v_max: float
+vmin: float
+vmax: float
kernel_init: Initializer = lecun_normal()

@nn.compact
def __call__(
self, embedding: chex.Array
) -> Tuple[distrax.EpsilonGreedy, chex.Array, chex.Array]:
-atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
+atoms = jnp.linspace(self.vmin, self.vmax, self.num_atoms)
q_logits = nn.Dense(self.action_dim * self.num_atoms, kernel_init=self.kernel_init)(
embedding
)
@@ -252,15 +252,15 @@ def __call__(

class DistributionalContinuousQNetwork(nn.Module):
num_atoms: int
-v_min: float
-v_max: float
+vmin: float
+vmax: float
kernel_init: Initializer = lecun_normal()

@nn.compact
def __call__(
self, embedding: chex.Array
) -> Tuple[distrax.EpsilonGreedy, chex.Array, chex.Array]:
-atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
+atoms = jnp.linspace(self.vmin, self.vmax, self.num_atoms)
q_logits = nn.Dense(self.num_atoms, kernel_init=self.kernel_init)(embedding)
q_dist = jax.nn.softmax(q_logits)
q_value = jnp.sum(q_dist * atoms, axis=-1)
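For reference, the renamed vmin/vmax/num_atoms attributes define the fixed support of the categorical return distribution these heads compute over. A tiny self-contained sketch of that computation, mirroring the lines shown above (shapes and values are illustrative only):

import jax
import jax.numpy as jnp

vmin, vmax, num_atoms, action_dim = 0.0, 500.0, 51, 4  # illustrative values
atoms = jnp.linspace(vmin, vmax, num_atoms)  # fixed support of the return distribution

# Stand-in for the Dense-layer output: (batch, action_dim * num_atoms) logits.
q_logits = jnp.zeros((2, action_dim * num_atoms))
q_logits = jnp.reshape(q_logits, (-1, action_dim, num_atoms))

q_dist = jax.nn.softmax(q_logits)            # per-action categorical distribution over atoms
q_values = jnp.sum(q_dist * atoms, axis=2)   # expected return per action: (batch, action_dim)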
3 changes: 0 additions & 3 deletions stoix/systems/ddpg/ff_d4pg.py
@@ -403,9 +403,6 @@ def learner_setup(
q_network_torso = hydra.utils.instantiate(config.network.q_network.pre_torso)
q_network_head = hydra.utils.instantiate(
config.network.q_network.critic_head,
-num_atoms=config.system.num_atoms,
-v_min=config.system.v_min,
-v_max=config.system.v_max,
)
q_network = CompositeNetwork([q_network_input, q_network_torso, q_network_head])

6 changes: 0 additions & 6 deletions stoix/systems/q_learning/ff_c51.py
@@ -290,9 +290,6 @@ def learner_setup(
config.network.actor_network.action_head,
action_dim=action_dim,
epsilon=config.system.training_epsilon,
-num_atoms=config.system.num_atoms,
-v_min=config.system.v_min,
-v_max=config.system.v_max,
)

q_network = Actor(torso=q_network_torso, action_head=q_network_action_head)
@@ -301,9 +298,6 @@ def learner_setup(
config.network.actor_network.action_head,
action_dim=action_dim,
epsilon=config.system.evaluation_epsilon,
-num_atoms=config.system.num_atoms,
-v_min=config.system.v_min,
-v_max=config.system.v_max,
)
eval_q_network = Actor(torso=q_network_torso, action_head=eval_q_network_action_head)
eval_q_network = EvalActorWrapper(actor=eval_q_network)
2 changes: 0 additions & 2 deletions stoix/systems/q_learning/ff_qr_dqn.py
@@ -304,7 +304,6 @@ def learner_setup(
config.network.actor_network.action_head,
action_dim=action_dim,
epsilon=config.system.training_epsilon,
-num_quantiles=config.system.num_quantiles,
)

q_network = Actor(torso=q_network_torso, action_head=q_network_action_head)
@@ -313,7 +312,6 @@ def learner_setup(
config.network.actor_network.action_head,
action_dim=action_dim,
epsilon=config.system.evaluation_epsilon,
-num_quantiles=config.system.num_quantiles,
)
eval_q_network = Actor(torso=q_network_torso, action_head=eval_q_network_action_head)
eval_q_network = EvalActorWrapper(actor=eval_q_network)
8 changes: 0 additions & 8 deletions stoix/systems/q_learning/ff_rainbow.py
@@ -355,10 +355,6 @@ def learner_setup(
config.network.actor_network.action_head,
action_dim=action_dim,
epsilon=config.system.training_epsilon,
-sigma_zero=config.system.sigma_zero,
-num_atoms=config.system.num_atoms,
-v_min=config.system.v_min,
-v_max=config.system.v_max,
)

q_network = Actor(torso=q_network_torso, action_head=q_network_action_head)
@@ -367,10 +363,6 @@ def learner_setup(
config.network.actor_network.action_head,
action_dim=action_dim,
epsilon=config.system.evaluation_epsilon,
-sigma_zero=config.system.sigma_zero,
-num_atoms=config.system.num_atoms,
-v_min=config.system.v_min,
-v_max=config.system.v_max,
)
eval_q_network = Actor(torso=q_network_torso, action_head=eval_q_network_action_head)
eval_q_network = EvalActorWrapper(actor=eval_q_network)
3 changes: 0 additions & 3 deletions stoix/systems/search/ff_mz.py
@@ -483,9 +483,6 @@ def learner_setup(
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_head = hydra.utils.instantiate(
config.network.critic_network.critic_head,
-num_atoms=config.system.critic_num_atoms,
-vmin=config.system.critic_vmin,
-vmax=config.system.critic_vmax,
)

actor_network = Actor(
3 changes: 0 additions & 3 deletions stoix/systems/search/ff_sampled_mz.py
@@ -619,9 +619,6 @@ def learner_setup(
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_head = hydra.utils.instantiate(
config.network.critic_network.critic_head,
-num_atoms=config.system.critic_num_atoms,
-vmin=config.system.critic_vmin,
-vmax=config.system.critic_vmax,
)

actor_network = Actor(
