Chore/refactor loss metrics #61

Merged 2 commits on Apr 22, 2024
5 changes: 2 additions & 3 deletions stoix/systems/ddpg/ff_d4pg.py
@@ -302,9 +302,8 @@ def _actor_loss_fn(

# PACK LOSS INFO
loss_info = {
"total_loss": actor_loss_info["actor_loss"] + q_loss_info["q_loss"],
"value_loss": q_loss_info["q_loss"],
"actor_loss": actor_loss_info["actor_loss"],
**actor_loss_info,
**q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info
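Note: the new packing simply merges the per-loss dictionaries instead of re-keying them. A minimal sketch of the resulting dict, using the key names from this diff (the values are made up):

# Minimal sketch of the new packing; values are illustrative only.
actor_loss_info = {"actor_loss": 0.42}  # returned by _actor_loss_fn
q_loss_info = {"q_loss": 1.30}          # returned by _q_loss_fn

loss_info = {**actor_loss_info, **q_loss_info}
# {"actor_loss": 0.42, "q_loss": 1.30}

# "total_loss" is no longer packed here; a logger can still derive it, e.g.
total_loss = loss_info["actor_loss"] + loss_info["q_loss"]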

5 changes: 2 additions & 3 deletions stoix/systems/ddpg/ff_ddpg.py
@@ -291,9 +291,8 @@ def _actor_loss_fn(

# PACK LOSS INFO
loss_info = {
"total_loss": actor_loss_info["actor_loss"] + q_loss_info["q_loss"],
"value_loss": q_loss_info["q_loss"],
"actor_loss": actor_loss_info["actor_loss"],
**actor_loss_info,
**q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info

32 changes: 16 additions & 16 deletions stoix/systems/ddpg/ff_td3.py
@@ -243,7 +243,7 @@ def _actor_loss_fn(
}
return actor_loss, loss_info

- params, opt_states, buffer_state, key, epoch_counter = update_state
+ params, opt_states, buffer_state, key = update_state

key, sample_key = jax.random.split(key, num=2)

@@ -305,31 +305,24 @@ def _actor_loss_fn(
q_new_params = QsAndTarget(q_new_online_params, new_target_q_params)

# PACK NEW PARAMS AND OPTIMISER STATE
- # Delayed policy updates
- time_to_update = jnp.mod(epoch_counter, config.system.policy_frequency) == 0
- actor_new_params = jax.lax.cond(
-     time_to_update, lambda _: actor_new_params, lambda _: params.actor_params, None
- )
new_params = DDPGParams(actor_new_params, q_new_params)
new_opt_state = DDPGOptStates(actor_new_opt_state, q_new_opt_state)

# PACK LOSS INFO
loss_info = {
"total_loss": actor_loss_info["actor_loss"] + q_loss_info["q_loss"],
"value_loss": q_loss_info["q_loss"],
"actor_loss": actor_loss_info["actor_loss"],
**actor_loss_info,
**q_loss_info,
}
- return (new_params, new_opt_state, buffer_state, key, epoch_counter + 1), loss_info
+ return (new_params, new_opt_state, buffer_state, key), loss_info

- epoch_counter = jnp.array(0)
- update_state = (params, opt_states, buffer_state, key, epoch_counter)
+ update_state = (params, opt_states, buffer_state, key)

# UPDATE EPOCHS
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config.system.epochs
)

- params, opt_states, buffer_state, key, epoch_counter = update_state
+ params, opt_states, buffer_state, key = update_state
learner_state = DDPGLearnerState(
params, opt_states, buffer_state, key, env_state, last_timestep
)
@@ -398,9 +391,16 @@ def learner_setup(
actor_lr = make_learning_rate(config.system.actor_lr, config, config.system.epochs)
q_lr = make_learning_rate(config.system.q_lr, config, config.system.epochs)

- actor_optim = optax.chain(
-     optax.clip_by_global_norm(config.system.max_grad_norm),
-     optax.adam(actor_lr, eps=1e-5),
+ def delayed_policy_update(step_count: int) -> bool:
+     should_update: bool = jnp.mod(step_count, config.system.policy_frequency) == 0
+     return should_update
+
+ actor_optim = optax.conditionally_mask(
+     optax.chain(
+         optax.clip_by_global_norm(config.system.max_grad_norm),
+         optax.adam(actor_lr, eps=1e-5),
+     ),
+     should_transform_fn=delayed_policy_update,
)
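Note: the delayed policy update now lives inside the actor optimiser rather than as an explicit epoch counter and jax.lax.cond around the parameter swap in the update loop. The following is a hand-rolled sketch of what such a conditional wrapper does, not the implementation of the transform used above; `conditionally_apply` and `ConditionalState` are hypothetical names introduced for illustration.

# Illustrative sketch: a gradient transformation that runs an inner optimiser
# only on steps where should_transform_fn(step) is True and emits zero updates
# otherwise, so the actor parameters are left untouched on skipped steps.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import optax


class ConditionalState(NamedTuple):
    step: jnp.ndarray            # number of update calls seen so far
    inner_state: optax.OptState  # state of the wrapped optimiser


def conditionally_apply(
    inner: optax.GradientTransformation,
    should_transform_fn: Callable[[jnp.ndarray], jnp.ndarray],
) -> optax.GradientTransformation:
    def init_fn(params):
        return ConditionalState(jnp.zeros([], jnp.int32), inner.init(params))

    def update_fn(updates, state, params=None):
        def apply(_):
            # Run the wrapped optimiser as usual.
            return inner.update(updates, state.inner_state, params)

        def skip(_):
            # Zero updates leave the parameters unchanged after optax.apply_updates.
            return jax.tree_util.tree_map(jnp.zeros_like, updates), state.inner_state

        new_updates, new_inner = jax.lax.cond(
            should_transform_fn(state.step), apply, skip, None
        )
        return new_updates, ConditionalState(state.step + 1, new_inner)

    return optax.GradientTransformation(init_fn, update_fn)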
q_optim = optax.chain(
optax.clip_by_global_norm(config.system.max_grad_norm),
6 changes: 3 additions & 3 deletions stoix/systems/mpo/ff_mpo.py
@@ -382,10 +382,10 @@ def _q_loss_fn(
new_opt_state = MPOOptStates(actor_new_opt_state, q_new_opt_state, dual_new_opt_state)

# PACK LOSS INFO
- loss_info = actor_loss_info._asdict()
+ actor_loss_info = actor_loss_info._asdict()
loss_info = {
-     **loss_info,
-     "value_loss": q_loss_info["q_loss"],
+     **actor_loss_info,
+     **q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info
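Note: the MPO actor loss info is a NamedTuple, so it is converted with ._asdict() before being spread into the merged dict. A small sketch of that pattern; the field names below are illustrative, not the exact fields of the MPO loss info type:

# Sketch of the packing above with hypothetical field names.
from typing import NamedTuple

class ActorLossInfo(NamedTuple):
    policy_loss: float
    kl_loss: float

actor_loss_info = ActorLossInfo(policy_loss=0.7, kl_loss=0.05)._asdict()
q_loss_info = {"q_loss": 1.1}

loss_info = {**actor_loss_info, **q_loss_info}
# {"policy_loss": 0.7, "kl_loss": 0.05, "q_loss": 1.1}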

6 changes: 3 additions & 3 deletions stoix/systems/mpo/ff_mpo_continuous.py
@@ -398,10 +398,10 @@ def _q_loss_fn(
new_opt_state = MPOOptStates(actor_new_opt_state, q_new_opt_state, dual_new_opt_state)

# PACK LOSS INFO
- loss_info = actor_loss_info._asdict()
+ actor_loss_info = actor_loss_info._asdict()
loss_info = {
-     **loss_info,
-     "value_loss": q_loss_info["q_loss"],
+     **actor_loss_info,
+     **q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info

29 changes: 15 additions & 14 deletions stoix/systems/ppo/ff_dpo_continuous.py
@@ -152,7 +152,11 @@ def _actor_loss_fn(
entropy = actor_policy.entropy(seed=rng_key).mean()

total_loss_actor = loss_actor - config.system.ent_coef * entropy
- return total_loss_actor, (loss_actor, entropy)
+ loss_info = {
+     "actor_loss": loss_actor,
+     "entropy": entropy,
+ }
+ return total_loss_actor, loss_info

def _critic_loss_fn(
critic_params: FrozenDict,
@@ -169,21 +173,24 @@ def _critic_loss_fn(
)

critic_total_loss = config.system.vf_coef * value_loss
- return critic_total_loss, (value_loss)
+ loss_info = {
+     "value_loss": value_loss,
+ }
+ return critic_total_loss, loss_info

# CALCULATE ACTOR LOSS
key, actor_loss_key = jax.random.split(key)
- actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
- actor_loss_info, actor_grads = actor_grad_fn(
+ actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True)
+ actor_grads, actor_loss_info = actor_grad_fn(
params.actor_params,
traj_batch,
advantages,
actor_loss_key,
)

# CALCULATE CRITIC LOSS
- critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
- critic_loss_info, critic_grads = critic_grad_fn(
+ critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True)
+ critic_grads, critic_loss_info = critic_grad_fn(
params.critic_params, traj_batch, targets
)
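Note: the switch from jax.value_and_grad to jax.grad also changes the return-value ordering, which is why the unpacking flips and the scalar loss value itself is no longer captured: jax.value_and_grad(f, has_aux=True) returns ((value, aux), grads), while jax.grad(f, has_aux=True) returns (grads, aux). A standalone sketch:

# The ordering difference behind the change above.
import jax
import jax.numpy as jnp

def loss_fn(w):
    loss = jnp.sum(w ** 2)
    aux = {"actor_loss": loss}
    return loss, aux

w = jnp.ones(3)
(value, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(w)  # old pattern
grads, aux = jax.grad(loss_fn, has_aux=True)(w)                     # new pattern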

@@ -224,15 +231,9 @@ def _critic_loss_fn(
new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state)

# PACK LOSS INFO
- total_loss = actor_loss_info[0] + critic_loss_info[0]
- value_loss = critic_loss_info[1]
- actor_loss = actor_loss_info[1][0]
- entropy = actor_loss_info[1][1]
loss_info = {
-     "total_loss": total_loss,
-     "value_loss": value_loss,
-     "actor_loss": actor_loss,
-     "entropy": entropy,
+     **actor_loss_info,
+     **critic_loss_info,
}
return (new_params, new_opt_state, key), loss_info
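Note: with the repacking removed, "total_loss" is no longer logged directly for the PPO-style systems; since the merged dict carries the unregularised actor loss, the entropy, and the unscaled value loss, the old quantity can still be recomputed downstream. A small sketch (coefficient names taken from the loss functions in this diff):

# Recovering the old "total_loss" from the merged dict, if a logger wants it.
def total_loss(loss_info: dict, ent_coef: float, vf_coef: float) -> float:
    # Same quantity the old packing logged as "total_loss".
    return (
        loss_info["actor_loss"]
        - ent_coef * loss_info["entropy"]
        + vf_coef * loss_info["value_loss"]
    )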

29 changes: 15 additions & 14 deletions stoix/systems/ppo/ff_ppo.py
@@ -151,7 +151,11 @@ def _actor_loss_fn(
entropy = actor_policy.entropy().mean()

total_loss_actor = loss_actor - config.system.ent_coef * entropy
- return total_loss_actor, (loss_actor, entropy)
+ loss_info = {
+     "actor_loss": loss_actor,
+     "entropy": entropy,
+ }
+ return total_loss_actor, loss_info

def _critic_loss_fn(
critic_params: FrozenDict,
@@ -168,17 +172,20 @@ def _critic_loss_fn(
)

critic_total_loss = config.system.vf_coef * value_loss
- return critic_total_loss, (value_loss)
+ loss_info = {
+     "value_loss": value_loss,
+ }
+ return critic_total_loss, loss_info

# CALCULATE ACTOR LOSS
- actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
- actor_loss_info, actor_grads = actor_grad_fn(
+ actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True)
+ actor_grads, actor_loss_info = actor_grad_fn(
params.actor_params, traj_batch, advantages
)

# CALCULATE CRITIC LOSS
- critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
- critic_loss_info, critic_grads = critic_grad_fn(
+ critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True)
+ critic_grads, critic_loss_info = critic_grad_fn(
params.critic_params, traj_batch, targets
)

@@ -219,15 +226,9 @@ def _critic_loss_fn(
new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state)

# PACK LOSS INFO
- total_loss = actor_loss_info[0] + critic_loss_info[0]
- value_loss = critic_loss_info[1]
- actor_loss = actor_loss_info[1][0]
- entropy = actor_loss_info[1][1]
loss_info = {
-     "total_loss": total_loss,
-     "value_loss": value_loss,
-     "actor_loss": actor_loss,
-     "entropy": entropy,
+     **actor_loss_info,
+     **critic_loss_info,
}
return (new_params, new_opt_state), loss_info

30 changes: 16 additions & 14 deletions stoix/systems/ppo/ff_ppo_continuous.py
@@ -152,7 +152,12 @@ def _actor_loss_fn(
entropy = actor_policy.entropy(seed=rng_key).mean()

total_loss_actor = loss_actor - config.system.ent_coef * entropy
- return total_loss_actor, (loss_actor, entropy)
+ loss_info = {
+     "actor_loss": loss_actor,
+     "entropy": entropy,
+ }
+
+ return total_loss_actor, loss_info

def _critic_loss_fn(
critic_params: FrozenDict,
@@ -169,21 +174,24 @@ def _critic_loss_fn(
)

critic_total_loss = config.system.vf_coef * value_loss
- return critic_total_loss, (value_loss)
+ loss_info = {
+     "value_loss": value_loss,
+ }
+ return critic_total_loss, loss_info

# CALCULATE ACTOR LOSS
key, actor_loss_key = jax.random.split(key)
- actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
- actor_loss_info, actor_grads = actor_grad_fn(
+ actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True)
+ actor_grads, actor_loss_info = actor_grad_fn(
params.actor_params,
traj_batch,
advantages,
actor_loss_key,
)

# CALCULATE CRITIC LOSS
- critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
- critic_loss_info, critic_grads = critic_grad_fn(
+ critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True)
+ critic_grads, critic_loss_info = critic_grad_fn(
params.critic_params, traj_batch, targets
)

@@ -224,15 +232,9 @@ def _critic_loss_fn(
new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state)

# PACK LOSS INFO
- total_loss = actor_loss_info[0] + critic_loss_info[0]
- value_loss = critic_loss_info[1]
- actor_loss = actor_loss_info[1][0]
- entropy = actor_loss_info[1][1]
loss_info = {
-     "total_loss": total_loss,
-     "value_loss": value_loss,
-     "actor_loss": actor_loss,
-     "entropy": entropy,
+     **actor_loss_info,
+     **critic_loss_info,
}
return (new_params, new_opt_state, key), loss_info

29 changes: 15 additions & 14 deletions stoix/systems/ppo/rec_ppo.py
@@ -223,7 +223,11 @@ def _actor_loss_fn(
entropy = actor_policy.entropy().mean()

total_loss = loss_actor - config.system.ent_coef * entropy
- return total_loss, (loss_actor, entropy)
+ loss_info = {
+     "actor_loss": loss_actor,
+     "entropy": entropy,
+ }
+ return total_loss, loss_info

def _critic_loss_fn(
critic_params: FrozenDict,
@@ -244,17 +248,20 @@ def _critic_loss_fn(
)

total_loss = config.system.vf_coef * value_loss
- return total_loss, (value_loss)
+ loss_info = {
+     "value_loss": value_loss,
+ }
+ return total_loss, loss_info

# CALCULATE ACTOR LOSS
- actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
- actor_loss_info, actor_grads = actor_grad_fn(
+ actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True)
+ actor_grads, actor_loss_info = actor_grad_fn(
params.actor_params, traj_batch, advantages
)

# CALCULATE CRITIC LOSS
- critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
- critic_loss_info, critic_grads = critic_grad_fn(
+ critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True)
+ critic_grads, critic_loss_info = critic_grad_fn(
params.critic_params, traj_batch, targets
)

@@ -294,15 +301,9 @@ def _critic_loss_fn(
new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state)

# PACK LOSS INFO
- total_loss = actor_loss_info[0] + critic_loss_info[0]
- value_loss = critic_loss_info[1]
- actor_loss = actor_loss_info[1][0]
- entropy = actor_loss_info[1][1]
loss_info = {
-     "total_loss": total_loss,
-     "value_loss": value_loss,
-     "actor_loss": actor_loss,
-     "entropy": entropy,
+     **actor_loss_info,
+     **critic_loss_info,
}

return (new_params, new_opt_state), loss_info
2 changes: 1 addition & 1 deletion stoix/systems/q_learning/ff_c51.py
@@ -214,7 +214,7 @@ def _q_loss_fn(

# PACK LOSS INFO
loss_info = {
"total_loss": q_loss_info["q_loss"],
**q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info
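Note: for the Q-learning systems the dict now just spreads q_loss_info, so any extra diagnostics the Q-loss function reports beyond "q_loss" are logged automatically instead of being collapsed into a single "total_loss" key. A tiny sketch ("q_error" below is a hypothetical extra key, not taken from this diff):

# Extra diagnostics returned by the Q-loss function now pass straight through.
q_loss_info = {"q_loss": 0.8, "q_error": 0.05}
loss_info = {**q_loss_info}  # {"q_loss": 0.8, "q_error": 0.05}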

2 changes: 1 addition & 1 deletion stoix/systems/q_learning/ff_ddqn.py
@@ -206,7 +206,7 @@ def _q_loss_fn(

# PACK LOSS INFO
loss_info = {
"total_loss": q_loss_info["q_loss"],
**q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info

2 changes: 1 addition & 1 deletion stoix/systems/q_learning/ff_dqn.py
@@ -205,7 +205,7 @@ def _q_loss_fn(

# PACK LOSS INFO
loss_info = {
"total_loss": q_loss_info["q_loss"],
**q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info

2 changes: 1 addition & 1 deletion stoix/systems/q_learning/ff_dqn_reg.py
@@ -209,7 +209,7 @@ def _q_loss_fn(

# PACK LOSS INFO
loss_info = {
"total_loss": q_loss_info["q_loss"],
**q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info

2 changes: 1 addition & 1 deletion stoix/systems/q_learning/ff_dueling_dqn.py
@@ -205,7 +205,7 @@ def _q_loss_fn(

# PACK LOSS INFO
loss_info = {
"total_loss": q_loss_info["q_loss"],
**q_loss_info,
}
return (new_params, new_opt_state, buffer_state, key), loss_info
