feat: add SPS for trainer (#129)
* chore: add SPS for trainer
EdanToledo authored Nov 12, 2024
1 parent d29d31b commit 5c48fea
Showing 30 changed files with 230 additions and 30 deletions.
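Every file in this commit applies the same change: after a training interval, the learner counts how many optimiser steps it performed and divides by the wall-clock time of that interval to log a steps-per-second (SPS) metric. A minimal, self-contained sketch of the idea follows; the helper function and the numbers in it are illustrative only and are not part of the commit or of any Stoix config.

def opt_steps_per_second(num_updates_per_eval: int, steps_per_update: int, elapsed_time: float) -> float:
    """Optimiser steps per second over one train/eval interval.

    `steps_per_update` stands in for `epochs` in the off-policy systems,
    `epochs * num_minibatches` in the PPO-style systems, or the per-network
    actor/critic step counts in AWR.
    """
    return (num_updates_per_eval * steps_per_update) / elapsed_time

# Illustrative call with made-up values: 4 updates of 8 steps over 2 seconds.
print(opt_steps_per_second(num_updates_per_eval=4, steps_per_update=8, elapsed_time=2.0))  # 16.0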
14 changes: 13 additions & 1 deletion stoix/systems/awr/ff_awr.py
@@ -583,7 +583,19 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        act_opt_steps_per_eval = config.arch.num_updates_per_eval * config.system.num_actor_steps
+        critic_opt_steps_per_eval = (
+            config.arch.num_updates_per_eval * config.system.num_critic_steps
+        )
+        total_opt_steps_per_eval = act_opt_steps_per_eval + critic_opt_steps_per_eval
+        train_metrics["actor_steps_per_second"] = act_opt_steps_per_eval / elapsed_time
+        train_metrics["critic_steps_per_second"] = critic_opt_steps_per_eval / elapsed_time
+        train_metrics["steps_per_second"] = total_opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
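Note on the AWR diffs above and below: because the actor and critic run different numbers of optimiser steps per update, three rates are logged. As a purely hypothetical worked example (values not taken from any config): with num_updates_per_eval = 4, num_actor_steps = 100, num_critic_steps = 200 and an elapsed_time of 10 s, actor_steps_per_second = 4 * 100 / 10 = 40, critic_steps_per_second = 4 * 200 / 10 = 80, and steps_per_second = 40 + 80 = 120.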
14 changes: 13 additions & 1 deletion stoix/systems/awr/ff_awr_continuous.py
@@ -588,7 +588,19 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        act_opt_steps_per_eval = config.arch.num_updates_per_eval * config.system.num_actor_steps
+        critic_opt_steps_per_eval = (
+            config.arch.num_updates_per_eval * config.system.num_critic_steps
+        )
+        total_opt_steps_per_eval = act_opt_steps_per_eval + critic_opt_steps_per_eval
+        train_metrics["actor_steps_per_second"] = act_opt_steps_per_eval / elapsed_time
+        train_metrics["critic_steps_per_second"] = critic_opt_steps_per_eval / elapsed_time
+        train_metrics["steps_per_second"] = total_opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/ddpg/ff_d4pg.py
@@ -635,7 +635,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/ddpg/ff_ddpg.py
@@ -588,7 +588,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/ddpg/ff_td3.py
@@ -614,7 +614,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/mpo/ff_mpo.py
@@ -690,7 +690,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/mpo/ff_mpo_continuous.py
@@ -721,7 +721,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/mpo/ff_vmpo.py
@@ -540,7 +540,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/mpo/ff_vmpo_continuous.py
@@ -615,7 +615,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
10 changes: 9 additions & 1 deletion stoix/systems/ppo/anakin/ff_dpo_continuous.py
@@ -510,7 +510,15 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (
+            config.system.epochs * config.system.num_minibatches
+        )
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
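Note on the PPO-family diffs (ff_dpo_continuous above and the ff_ppo*/rec_ppo files below): each epoch is split into minibatches and every minibatch pass is one optimiser step, so the count is epochs * num_minibatches per update. With hypothetical values num_updates_per_eval = 4, epochs = 4, num_minibatches = 8 and elapsed_time = 10 s, that is 4 * 4 * 8 = 128 optimiser steps, i.e. steps_per_second = 12.8.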
10 changes: 9 additions & 1 deletion stoix/systems/ppo/anakin/ff_ppo.py
@@ -500,7 +500,15 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (
+            config.system.epochs * config.system.num_minibatches
+        )
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
10 changes: 9 additions & 1 deletion stoix/systems/ppo/anakin/ff_ppo_continuous.py
@@ -511,7 +511,15 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (
+            config.system.epochs * config.system.num_minibatches
+        )
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
10 changes: 9 additions & 1 deletion stoix/systems/ppo/anakin/ff_ppo_penalty.py
@@ -509,7 +509,15 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (
+            config.system.epochs * config.system.num_minibatches
+        )
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
10 changes: 9 additions & 1 deletion stoix/systems/ppo/anakin/ff_ppo_penalty_continuous.py
@@ -519,7 +519,15 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (
+            config.system.epochs * config.system.num_minibatches
+        )
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
10 changes: 9 additions & 1 deletion stoix/systems/ppo/anakin/rec_ppo.py
@@ -682,7 +682,15 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (
+            config.system.epochs * config.system.num_minibatches
+        )
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/q_learning/ff_c51.py
@@ -499,7 +499,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/q_learning/ff_ddqn.py
@@ -486,7 +486,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/q_learning/ff_dqn.py
@@ -492,7 +492,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()
8 changes: 7 additions & 1 deletion stoix/systems/q_learning/ff_dqn_reg.py
@@ -489,7 +489,13 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
         if ep_completed: # only log episode metrics if an episode was completed in the rollout.
             logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
-        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+        train_metrics = learner_output.train_metrics
+        # Calculate the number of optimiser steps per second. Since gradients are aggregated
+        # across the device and batch axis, we don't consider updates per device/batch as part of
+        # the SPS for the learner.
+        opt_steps_per_eval = config.arch.num_updates_per_eval * (config.system.epochs)
+        train_metrics["steps_per_second"] = opt_steps_per_eval / elapsed_time
+        logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Prepare for evaluation.
         start_time = time.time()