From 47dd62735d4ba367111e94739b1bd63fc6b165bf Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Fri, 19 Apr 2024 15:12:22 +0100 Subject: [PATCH 1/2] chore: refactor the metrics returned by losses --- stoix/systems/ddpg/ff_d4pg.py | 5 ++- stoix/systems/ddpg/ff_ddpg.py | 5 ++- stoix/systems/ddpg/ff_td3.py | 32 ++++++++++---------- stoix/systems/mpo/ff_mpo.py | 6 ++-- stoix/systems/mpo/ff_mpo_continuous.py | 6 ++-- stoix/systems/ppo/ff_dpo_continuous.py | 29 +++++++++--------- stoix/systems/ppo/ff_ppo.py | 29 +++++++++--------- stoix/systems/ppo/ff_ppo_continuous.py | 30 +++++++++--------- stoix/systems/ppo/rec_ppo.py | 29 +++++++++--------- stoix/systems/q_learning/ff_c51.py | 2 +- stoix/systems/q_learning/ff_ddqn.py | 2 +- stoix/systems/q_learning/ff_dqn.py | 2 +- stoix/systems/q_learning/ff_dqn_reg.py | 2 +- stoix/systems/q_learning/ff_dueling_dqn.py | 2 +- stoix/systems/q_learning/ff_mdqn.py | 2 +- stoix/systems/q_learning/ff_qr_dqn.py | 2 +- stoix/systems/sac/ff_sac.py | 11 ++----- stoix/systems/search/ff_az.py | 29 +++++++++--------- stoix/systems/search/ff_mz.py | 19 ++---------- stoix/systems/search/ff_sampled_az.py | 29 +++++++++--------- stoix/systems/search/ff_sampled_mz.py | 19 ++---------- stoix/systems/vpg/ff_reinforce.py | 29 +++++++++--------- stoix/systems/vpg/ff_reinforce_continuous.py | 24 ++++++--------- 23 files changed, 158 insertions(+), 187 deletions(-) diff --git a/stoix/systems/ddpg/ff_d4pg.py b/stoix/systems/ddpg/ff_d4pg.py index 435f8446..1bcc6d93 100644 --- a/stoix/systems/ddpg/ff_d4pg.py +++ b/stoix/systems/ddpg/ff_d4pg.py @@ -302,9 +302,8 @@ def _actor_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": actor_loss_info["actor_loss"] + q_loss_info["q_loss"], - "value_loss": q_loss_info["q_loss"], - "actor_loss": actor_loss_info["actor_loss"], + **actor_loss_info, + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/ddpg/ff_ddpg.py b/stoix/systems/ddpg/ff_ddpg.py index f7e7018f..028e3741 100644 --- a/stoix/systems/ddpg/ff_ddpg.py +++ b/stoix/systems/ddpg/ff_ddpg.py @@ -291,9 +291,8 @@ def _actor_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": actor_loss_info["actor_loss"] + q_loss_info["q_loss"], - "value_loss": q_loss_info["q_loss"], - "actor_loss": actor_loss_info["actor_loss"], + **actor_loss_info, + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/ddpg/ff_td3.py b/stoix/systems/ddpg/ff_td3.py index 00311086..3ae6ed06 100644 --- a/stoix/systems/ddpg/ff_td3.py +++ b/stoix/systems/ddpg/ff_td3.py @@ -243,7 +243,7 @@ def _actor_loss_fn( } return actor_loss, loss_info - params, opt_states, buffer_state, key, epoch_counter = update_state + params, opt_states, buffer_state, key = update_state key, sample_key = jax.random.split(key, num=2) @@ -305,31 +305,24 @@ def _actor_loss_fn( q_new_params = QsAndTarget(q_new_online_params, new_target_q_params) # PACK NEW PARAMS AND OPTIMISER STATE - # Delayed policy updates - time_to_update = jnp.mod(epoch_counter, config.system.policy_frequency) == 0 - actor_new_params = jax.lax.cond( - time_to_update, lambda _: actor_new_params, lambda _: params.actor_params, None - ) new_params = DDPGParams(actor_new_params, q_new_params) new_opt_state = DDPGOptStates(actor_new_opt_state, q_new_opt_state) # PACK LOSS INFO loss_info = { - "total_loss": actor_loss_info["actor_loss"] + q_loss_info["q_loss"], - "value_loss": q_loss_info["q_loss"], - "actor_loss": actor_loss_info["actor_loss"], + **actor_loss_info, + 
**q_loss_info, } - return (new_params, new_opt_state, buffer_state, key, epoch_counter + 1), loss_info + return (new_params, new_opt_state, buffer_state, key), loss_info - epoch_counter = jnp.array(0) - update_state = (params, opt_states, buffer_state, key, epoch_counter) + update_state = (params, opt_states, buffer_state, key) # UPDATE EPOCHS update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.epochs ) - params, opt_states, buffer_state, key, epoch_counter = update_state + params, opt_states, buffer_state, key = update_state learner_state = DDPGLearnerState( params, opt_states, buffer_state, key, env_state, last_timestep ) @@ -398,9 +391,16 @@ def learner_setup( actor_lr = make_learning_rate(config.system.actor_lr, config, config.system.epochs) q_lr = make_learning_rate(config.system.q_lr, config, config.system.epochs) - actor_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(actor_lr, eps=1e-5), + def delayed_policy_update(step_count: int) -> bool: + should_update: bool = jnp.mod(step_count, config.system.policy_frequency) == 0 + return should_update + + actor_optim = optax.maybe_update( + optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ), + delayed_policy_update, ) q_optim = optax.chain( optax.clip_by_global_norm(config.system.max_grad_norm), diff --git a/stoix/systems/mpo/ff_mpo.py b/stoix/systems/mpo/ff_mpo.py index 4804c05f..c16c39ff 100644 --- a/stoix/systems/mpo/ff_mpo.py +++ b/stoix/systems/mpo/ff_mpo.py @@ -382,10 +382,10 @@ def _q_loss_fn( new_opt_state = MPOOptStates(actor_new_opt_state, q_new_opt_state, dual_new_opt_state) # PACK LOSS INFO - loss_info = actor_loss_info._asdict() + actor_loss_info = actor_loss_info._asdict() loss_info = { - **loss_info, - "value_loss": q_loss_info["q_loss"], + **actor_loss_info, + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/mpo/ff_mpo_continuous.py b/stoix/systems/mpo/ff_mpo_continuous.py index 30683424..202df471 100644 --- a/stoix/systems/mpo/ff_mpo_continuous.py +++ b/stoix/systems/mpo/ff_mpo_continuous.py @@ -398,10 +398,10 @@ def _q_loss_fn( new_opt_state = MPOOptStates(actor_new_opt_state, q_new_opt_state, dual_new_opt_state) # PACK LOSS INFO - loss_info = actor_loss_info._asdict() + actor_loss_info = actor_loss_info._asdict() loss_info = { - **loss_info, - "value_loss": q_loss_info["q_loss"], + **actor_loss_info, + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/ppo/ff_dpo_continuous.py b/stoix/systems/ppo/ff_dpo_continuous.py index 92e88fe5..8029c444 100644 --- a/stoix/systems/ppo/ff_dpo_continuous.py +++ b/stoix/systems/ppo/ff_dpo_continuous.py @@ -152,7 +152,11 @@ def _actor_loss_fn( entropy = actor_policy.entropy(seed=rng_key).mean() total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + loss_info = { + "actor_loss": loss_actor, + "entropy": entropy, + } + return total_loss_actor, loss_info def _critic_loss_fn( critic_params: FrozenDict, @@ -169,12 +173,15 @@ def _critic_loss_fn( ) critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + loss_info = { + "value_loss": value_loss, + } + return critic_total_loss, loss_info # CALCULATE ACTOR LOSS key, actor_loss_key = jax.random.split(key) - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, 
actor_grads = actor_grad_fn( + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn( params.actor_params, traj_batch, advantages, @@ -182,8 +189,8 @@ def _critic_loss_fn( ) # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn( params.critic_params, traj_batch, targets ) @@ -224,15 +231,9 @@ def _critic_loss_fn( new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, + **actor_loss_info, + **critic_loss_info, } return (new_params, new_opt_state, key), loss_info diff --git a/stoix/systems/ppo/ff_ppo.py b/stoix/systems/ppo/ff_ppo.py index 386604ca..bf3ae7bd 100644 --- a/stoix/systems/ppo/ff_ppo.py +++ b/stoix/systems/ppo/ff_ppo.py @@ -151,7 +151,11 @@ def _actor_loss_fn( entropy = actor_policy.entropy().mean() total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + loss_info = { + "actor_loss": loss_actor, + "entropy": entropy, + } + return total_loss_actor, loss_info def _critic_loss_fn( critic_params: FrozenDict, @@ -168,17 +172,20 @@ def _critic_loss_fn( ) critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + loss_info = { + "value_loss": value_loss, + } + return critic_total_loss, loss_info # CALCULATE ACTOR LOSS - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = actor_grad_fn( + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn( params.actor_params, traj_batch, advantages ) # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn( params.critic_params, traj_batch, targets ) @@ -219,15 +226,9 @@ def _critic_loss_fn( new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, + **actor_loss_info, + **critic_loss_info, } return (new_params, new_opt_state), loss_info diff --git a/stoix/systems/ppo/ff_ppo_continuous.py b/stoix/systems/ppo/ff_ppo_continuous.py index 536fcc98..59ee305b 100644 --- a/stoix/systems/ppo/ff_ppo_continuous.py +++ b/stoix/systems/ppo/ff_ppo_continuous.py @@ -152,7 +152,12 @@ def _actor_loss_fn( entropy = actor_policy.entropy(seed=rng_key).mean() total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + loss_info = { + "actor_loss": loss_actor, + "entropy": entropy, + } + + return total_loss_actor, loss_info def _critic_loss_fn( critic_params: FrozenDict, @@ -169,12 +174,15 @@ def _critic_loss_fn( ) critic_total_loss = config.system.vf_coef * 
value_loss - return critic_total_loss, (value_loss) + loss_info = { + "value_loss": value_loss, + } + return critic_total_loss, loss_info # CALCULATE ACTOR LOSS key, actor_loss_key = jax.random.split(key) - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = actor_grad_fn( + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn( params.actor_params, traj_batch, advantages, @@ -182,8 +190,8 @@ def _critic_loss_fn( ) # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn( params.critic_params, traj_batch, targets ) @@ -224,15 +232,9 @@ def _critic_loss_fn( new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, + **actor_loss_info, + **critic_loss_info, } return (new_params, new_opt_state, key), loss_info diff --git a/stoix/systems/ppo/rec_ppo.py b/stoix/systems/ppo/rec_ppo.py index 93607f58..35d42cf5 100644 --- a/stoix/systems/ppo/rec_ppo.py +++ b/stoix/systems/ppo/rec_ppo.py @@ -223,7 +223,11 @@ def _actor_loss_fn( entropy = actor_policy.entropy().mean() total_loss = loss_actor - config.system.ent_coef * entropy - return total_loss, (loss_actor, entropy) + loss_info = { + "actor_loss": loss_actor, + "entropy": entropy, + } + return total_loss, loss_info def _critic_loss_fn( critic_params: FrozenDict, @@ -244,17 +248,20 @@ def _critic_loss_fn( ) total_loss = config.system.vf_coef * value_loss - return total_loss, (value_loss) + loss_info = { + "value_loss": value_loss, + } + return total_loss, loss_info # CALCULATE ACTOR LOSS - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = actor_grad_fn( + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn( params.actor_params, traj_batch, advantages ) # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn( params.critic_params, traj_batch, targets ) @@ -294,15 +301,9 @@ def _critic_loss_fn( new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, + **actor_loss_info, + **critic_loss_info, } return (new_params, new_opt_state), loss_info diff --git a/stoix/systems/q_learning/ff_c51.py b/stoix/systems/q_learning/ff_c51.py index 9342bb64..103e7dbf 100644 --- a/stoix/systems/q_learning/ff_c51.py +++ b/stoix/systems/q_learning/ff_c51.py @@ -214,7 +214,7 @@ def _q_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": q_loss_info["q_loss"], + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git 
a/stoix/systems/q_learning/ff_ddqn.py b/stoix/systems/q_learning/ff_ddqn.py index 27fb9eca..b5741c91 100644 --- a/stoix/systems/q_learning/ff_ddqn.py +++ b/stoix/systems/q_learning/ff_ddqn.py @@ -206,7 +206,7 @@ def _q_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": q_loss_info["q_loss"], + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/q_learning/ff_dqn.py b/stoix/systems/q_learning/ff_dqn.py index 5847ded4..657df3a7 100644 --- a/stoix/systems/q_learning/ff_dqn.py +++ b/stoix/systems/q_learning/ff_dqn.py @@ -205,7 +205,7 @@ def _q_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": q_loss_info["q_loss"], + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/q_learning/ff_dqn_reg.py b/stoix/systems/q_learning/ff_dqn_reg.py index 8f06bed9..17e0904b 100644 --- a/stoix/systems/q_learning/ff_dqn_reg.py +++ b/stoix/systems/q_learning/ff_dqn_reg.py @@ -209,7 +209,7 @@ def _q_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": q_loss_info["q_loss"], + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/q_learning/ff_dueling_dqn.py b/stoix/systems/q_learning/ff_dueling_dqn.py index 8fb9cc0c..ae4f6585 100644 --- a/stoix/systems/q_learning/ff_dueling_dqn.py +++ b/stoix/systems/q_learning/ff_dueling_dqn.py @@ -205,7 +205,7 @@ def _q_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": q_loss_info["q_loss"], + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/q_learning/ff_mdqn.py b/stoix/systems/q_learning/ff_mdqn.py index ff1611f0..fab1ea19 100644 --- a/stoix/systems/q_learning/ff_mdqn.py +++ b/stoix/systems/q_learning/ff_mdqn.py @@ -209,7 +209,7 @@ def _q_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": q_loss_info["q_loss"], + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/q_learning/ff_qr_dqn.py b/stoix/systems/q_learning/ff_qr_dqn.py index 154a384d..bc963772 100644 --- a/stoix/systems/q_learning/ff_qr_dqn.py +++ b/stoix/systems/q_learning/ff_qr_dqn.py @@ -230,7 +230,7 @@ def _q_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": q_loss_info["q_loss"], + **q_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/sac/ff_sac.py b/stoix/systems/sac/ff_sac.py index 8d93b8f9..bbad0136 100644 --- a/stoix/systems/sac/ff_sac.py +++ b/stoix/systems/sac/ff_sac.py @@ -292,14 +292,9 @@ def _actor_loss_fn( # PACK LOSS INFO loss_info = { - "total_loss": actor_loss_info["actor_loss"] - + q_loss_info["q_loss"] - + alpha_loss_info["alpha_loss"], - "value_loss": q_loss_info["q_loss"], - "actor_loss": actor_loss_info["actor_loss"], - "entropy": actor_loss_info["entropy"], - "alpha_loss": alpha_loss_info["alpha_loss"], - "alpha": alpha_loss_info["alpha"], + **actor_loss_info, + **q_loss_info, + **alpha_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/search/ff_az.py b/stoix/systems/search/ff_az.py index 0f73c9e5..1a16ad85 100644 --- a/stoix/systems/search/ff_az.py +++ b/stoix/systems/search/ff_az.py @@ -249,7 +249,11 @@ def _actor_loss_fn( entropy = actor_policy.entropy().mean() total_loss_actor = actor_loss - config.system.ent_coef * entropy - return total_loss_actor, (actor_loss, entropy) + loss_info = { + "actor_loss": actor_loss, + "entropy": entropy, + } + return total_loss_actor, 
loss_info def _critic_loss_fn( critic_params: FrozenDict, @@ -271,7 +275,10 @@ def _critic_loss_fn( value_loss = rlax.l2_loss(value, targets).mean() critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + loss_info = { + "value_loss": value_loss, + } + return critic_total_loss, loss_info params, opt_states, buffer_state, key = update_state @@ -282,12 +289,12 @@ def _critic_loss_fn( sequence: ExItTransition = sequence_sample.experience # CALCULATE ACTOR LOSS - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = actor_grad_fn(params.actor_params, sequence) + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn(params.actor_params, sequence) # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn(params.critic_params, sequence) + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn(params.critic_params, sequence) # Compute the parallel mean (pmean) over the batch. # This calculation is inspired by the Anakin architecture demo notebook. @@ -326,15 +333,9 @@ def _critic_loss_fn( new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, + **actor_loss_info, + **critic_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/search/ff_mz.py b/stoix/systems/search/ff_mz.py index b1af251c..115822ad 100644 --- a/stoix/systems/search/ff_mz.py +++ b/stoix/systems/search/ff_mz.py @@ -386,7 +386,7 @@ def unroll_fn( losses["actor"] + losses["value"] + losses["reward"] - losses["entropy"] ) - return total_loss, (losses) + return total_loss, losses params, opt_state, buffer_state, key = update_state @@ -397,8 +397,8 @@ def unroll_fn( sequence: ExItTransition = sequence_sample.experience # CALCULATE LOSS - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - loss_info, grads = grad_fn( + grad_fn = jax.grad(_loss_fn, has_aux=True) + grads, loss_info = grad_fn( params, sequence, ) @@ -415,19 +415,6 @@ def unroll_fn( updates, new_opt_state = update_fn(grads, opt_state) new_params = optax.apply_updates(params, updates) - # PACK LOSS INFO - total_loss = loss_info[0] - actor_loss = loss_info[1]["actor"] - value_loss = loss_info[1]["value"] - entropy = loss_info[1]["entropy"] - reward_loss = loss_info[1]["reward"] - loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, - "reward_loss": reward_loss, - } return (new_params, new_opt_state, buffer_state, key), loss_info update_state = (params, opt_state, buffer_state, key) diff --git a/stoix/systems/search/ff_sampled_az.py b/stoix/systems/search/ff_sampled_az.py index f705b2d0..faf25371 100644 --- a/stoix/systems/search/ff_sampled_az.py +++ b/stoix/systems/search/ff_sampled_az.py @@ -383,7 +383,11 @@ def _actor_loss_fn( entropy = actor_policy.entropy(seed=rng_key).mean() total_loss_actor = actor_loss - config.system.ent_coef * entropy - return total_loss_actor, (actor_loss, entropy) + loss_info = { + "actor_loss": actor_loss, + "entropy": entropy, + } + return 
total_loss_actor, loss_info def _critic_loss_fn( critic_params: FrozenDict, @@ -405,7 +409,10 @@ def _critic_loss_fn( value_loss = rlax.l2_loss(value, targets).mean() critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + loss_info = { + "value_loss": value_loss, + } + return critic_total_loss, loss_info params, opt_states, buffer_state, key = update_state @@ -416,12 +423,12 @@ def _critic_loss_fn( sequence: SampledExItTransition = sequence_sample.experience # CALCULATE ACTOR LOSS - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = actor_grad_fn(params.actor_params, sequence, actor_key) + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn(params.actor_params, sequence, actor_key) # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn(params.critic_params, sequence) + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn(params.critic_params, sequence) # Compute the parallel mean (pmean) over the batch. # This calculation is inspired by the Anakin architecture demo notebook. @@ -460,15 +467,9 @@ def _critic_loss_fn( new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, + **actor_loss_info, + **critic_loss_info, } return (new_params, new_opt_state, buffer_state, key), loss_info diff --git a/stoix/systems/search/ff_sampled_mz.py b/stoix/systems/search/ff_sampled_mz.py index eab39686..57af0b05 100644 --- a/stoix/systems/search/ff_sampled_mz.py +++ b/stoix/systems/search/ff_sampled_mz.py @@ -520,7 +520,7 @@ def unroll_fn( losses["actor"] + losses["value"] + losses["reward"] - losses["entropy"] ) - return total_loss, (losses) + return total_loss, losses params, opt_state, buffer_state, key = update_state @@ -531,8 +531,8 @@ def unroll_fn( sequence: SampledExItTransition = sequence_sample.experience # CALCULATE LOSS - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - loss_info, grads = grad_fn(params, sequence, loss_key) + grad_fn = jax.grad(_loss_fn, has_aux=True) + grads, loss_info = grad_fn(params, sequence, loss_key) # Compute the parallel mean (pmean) over the batch. # This calculation is inspired by the Anakin architecture demo notebook. 
@@ -546,19 +546,6 @@ def unroll_fn( updates, new_opt_state = update_fn(grads, opt_state) new_params = optax.apply_updates(params, updates) - # PACK LOSS INFO - total_loss = loss_info[0] - actor_loss = loss_info[1]["actor"] - value_loss = loss_info[1]["value"] - entropy = loss_info[1]["entropy"] - reward_loss = loss_info[1]["reward"] - loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, - "reward_loss": reward_loss, - } return (new_params, new_opt_state, buffer_state, key), loss_info update_state = (params, opt_state, buffer_state, key) diff --git a/stoix/systems/vpg/ff_reinforce.py b/stoix/systems/vpg/ff_reinforce.py index 979b15fe..c00b9cf8 100644 --- a/stoix/systems/vpg/ff_reinforce.py +++ b/stoix/systems/vpg/ff_reinforce.py @@ -105,7 +105,11 @@ def _actor_loss_fn( entropy = actor_policy.entropy().mean() total_loss_actor = loss_actor.mean() - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + loss_info = { + "actor_loss": loss_actor, + "entropy": entropy, + } + return total_loss_actor, loss_info def _critic_loss_fn( critic_params: FrozenDict, @@ -120,11 +124,14 @@ def _critic_loss_fn( value_loss = rlax.l2_loss(value, targets).mean() critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + loss_info = { + "value_loss": value_loss, + } + return critic_total_loss, loss_info # CALCULATE ACTOR LOSS - actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) - actor_loss_info, actor_grads = actor_grad_fn( + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn( params.actor_params, traj_batch.obs, traj_batch.action, @@ -133,8 +140,8 @@ def _critic_loss_fn( ) # CALCULATE CRITIC LOSS - critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn( params.critic_params, traj_batch.obs, monte_carlo_returns ) @@ -175,15 +182,9 @@ def _critic_loss_fn( new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] loss_info = { - "total_loss": total_loss, - "value_loss": value_loss, - "actor_loss": actor_loss, - "entropy": entropy, + **actor_loss_info, + **critic_loss_info, } learner_state = LearnerState(new_params, new_opt_state, key, env_state, last_timestep) diff --git a/stoix/systems/vpg/ff_reinforce_continuous.py b/stoix/systems/vpg/ff_reinforce_continuous.py index e327478e..83920da5 100644 --- a/stoix/systems/vpg/ff_reinforce_continuous.py +++ b/stoix/systems/vpg/ff_reinforce_continuous.py @@ -106,7 +106,8 @@ def _actor_loss_fn( entropy = actor_policy.entropy(seed=rng_key).mean() total_loss_actor = loss_actor.mean() - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + loss_info = {"actor_loss": loss_actor, "entropy": entropy} + return total_loss_actor, loss_info def _critic_loss_fn( critic_params: FrozenDict, @@ -121,12 +122,13 @@ def _critic_loss_fn( value_loss = rlax.l2_loss(value, targets).mean() critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + loss_info = {"value_loss": value_loss} + return critic_total_loss, loss_info # CALCULATE ACTOR LOSS key, actor_loss_key = 
jax.random.split(key)
-        actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
-        actor_loss_info, actor_grads = actor_grad_fn(
+        actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True)
+        actor_grads, actor_loss_info = actor_grad_fn(
             params.actor_params,
             traj_batch.obs,
             traj_batch.action,
             monte_carlo_returns,
             actor_loss_key,
         )
 
         # CALCULATE CRITIC LOSS
-        critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
-        critic_loss_info, critic_grads = critic_grad_fn(
+        critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True)
+        critic_grads, critic_loss_info = critic_grad_fn(
             params.critic_params, traj_batch.obs, monte_carlo_returns
         )
 
@@ -178,15 +180,9 @@ def _critic_loss_fn(
         new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state)
 
         # PACK LOSS INFO
-        total_loss = actor_loss_info[0] + critic_loss_info[0]
-        value_loss = critic_loss_info[1]
-        actor_loss = actor_loss_info[1][0]
-        entropy = actor_loss_info[1][1]
         loss_info = {
-            "total_loss": total_loss,
-            "value_loss": value_loss,
-            "actor_loss": actor_loss,
-            "entropy": entropy,
+            **actor_loss_info,
+            **critic_loss_info,
         }
 
         learner_state = LearnerState(new_params, new_opt_state, key, env_state, last_timestep)

From b1ffc2d68430624ecb221bb296f6de07acf703d2 Mon Sep 17 00:00:00 2001
From: EdanToledo
Date: Fri, 19 Apr 2024 17:46:08 +0000
Subject: [PATCH 2/2] fix: td3 policy delay optax mask

---
 stoix/systems/ddpg/ff_td3.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/stoix/systems/ddpg/ff_td3.py b/stoix/systems/ddpg/ff_td3.py
index 3ae6ed06..a54dfa8f 100644
--- a/stoix/systems/ddpg/ff_td3.py
+++ b/stoix/systems/ddpg/ff_td3.py
@@ -395,12 +395,12 @@ def delayed_policy_update(step_count: int) -> bool:
         should_update: bool = jnp.mod(step_count, config.system.policy_frequency) == 0
         return should_update
 
-    actor_optim = optax.maybe_update(
+    actor_optim = optax.conditionally_mask(
         optax.chain(
             optax.clip_by_global_norm(config.system.max_grad_norm),
             optax.adam(actor_lr, eps=1e-5),
         ),
-        delayed_policy_update,
+        should_transform_fn=delayed_policy_update,
     )
     q_optim = optax.chain(
         optax.clip_by_global_norm(config.system.max_grad_norm),
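
Notes on the metrics refactor in PATCH 1/2 (illustrative sketch, not part of the patch): each loss function now returns its metrics as a dict in the has_aux payload, the learners switch from jax.value_and_grad to jax.grad (with has_aux=True the former returns ((loss, aux), grads) while the latter returns (grads, aux), hence the flipped unpacking throughout the diff), and the per-update metrics are merged with dict unpacking instead of being re-packed by hand, so a combined total_loss is no longer logged. The toy loss functions, shapes, and coefficients below are assumptions for the sake of a self-contained example, not Stoix code.

    from typing import Dict, Tuple

    import jax
    import jax.numpy as jnp

    Metrics = Dict[str, jnp.ndarray]


    def actor_loss_fn(actor_params: jnp.ndarray, obs: jnp.ndarray) -> Tuple[jnp.ndarray, Metrics]:
        # Stand-in policy loss: returns (scalar_loss, metrics_dict) as the aux payload.
        actor_loss = jnp.mean((obs @ actor_params) ** 2)
        entropy = jnp.asarray(0.0)  # placeholder entropy bonus
        return actor_loss - 0.01 * entropy, {"actor_loss": actor_loss, "entropy": entropy}


    def critic_loss_fn(
        critic_params: jnp.ndarray, obs: jnp.ndarray, targets: jnp.ndarray
    ) -> Tuple[jnp.ndarray, Metrics]:
        value_loss = jnp.mean((obs @ critic_params - targets) ** 2)
        return 0.5 * value_loss, {"value_loss": value_loss}


    obs = jnp.ones((4, 3))
    targets = jnp.zeros((4,))
    actor_params = jnp.zeros((3,))
    critic_params = jnp.zeros((3,))

    # jax.grad(..., has_aux=True) yields (grads, aux); the previously used
    # jax.value_and_grad(..., has_aux=True) yields ((loss, aux), grads).
    actor_grads, actor_loss_info = jax.grad(actor_loss_fn, has_aux=True)(actor_params, obs)
    critic_grads, critic_loss_info = jax.grad(critic_loss_fn, has_aux=True)(
        critic_params, obs, targets
    )

    # Metrics are merged rather than re-packed under a hand-computed "total_loss".
    loss_info = {**actor_loss_info, **critic_loss_info}

Downstream loggers therefore see the individual keys (actor_loss, value_loss, q_loss, entropy, alpha, and so on) directly, with each loss function owning the names it reports.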
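
Notes on PATCH 2/2 (illustrative sketch, not part of the patch): the TD3 delayed policy update is now expressed by wrapping the actor optimiser instead of carrying an epoch counter and a jax.lax.cond around the parameter swap. optax.maybe_update, used in the first commit, passes the incoming updates through on the skipped steps rather than masking them, which is presumably why the follow-up switches to optax.conditionally_mask: it zeroes the updates, so the actor parameters only move every policy_frequency-th optimiser step. The sketch below shows that wiring in isolation, assuming an optax version recent enough to provide conditionally_mask with the keyword used in the commit; the hyperparameter values are placeholders, not the Stoix config.

    import jax.numpy as jnp
    import optax

    policy_frequency = 2  # placeholder for config.system.policy_frequency
    max_grad_norm = 0.5   # placeholder for config.system.max_grad_norm
    actor_lr = 3e-4       # placeholder for the actor learning-rate schedule


    def delayed_policy_update(step_count: jnp.ndarray) -> jnp.ndarray:
        # True on every policy_frequency-th optimiser step; the wrapper passes
        # its internal step count to this predicate.
        return jnp.mod(step_count, policy_frequency) == 0


    actor_optim = optax.conditionally_mask(
        optax.chain(
            optax.clip_by_global_norm(max_grad_norm),
            optax.adam(actor_lr, eps=1e-5),
        ),
        should_transform_fn=delayed_policy_update,
    )

    params = jnp.zeros((3,))
    opt_state = actor_optim.init(params)
    grads = jnp.ones((3,))

    # On steps where the predicate is False the wrapper returns all-zero updates,
    # so applying them leaves the actor parameters unchanged.
    updates, opt_state = actor_optim.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

Because the masking lives inside the optimiser, the q-function update path and the learner scan over config.system.epochs stay exactly as they are in the rest of the diff.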