chore(off-policy): clear redundant distributed grad average (#253)
Gaiejj authored Jul 2, 2023
1 parent 1a99b20 commit af2951d
Showing 4 changed files with 15 additions and 19 deletions.
16 changes: 9 additions & 7 deletions omnisafe/algorithms/algo_wrapper.py
@@ -88,13 +88,15 @@ def _init_config(self) -> Config:
             self.algo in ALGORITHMS['all']
         ), f"{self.algo} doesn't exist. Please choose from {ALGORITHMS['all']}."
         self.algo_type = ALGORITHM2TYPE.get(self.algo, '')
-        if self.algo_type in ['model-based', 'offline'] and self.train_terminal_cfgs is not None:
-            assert (
-                self.train_terminal_cfgs['parallel'] == 1
-            ), 'model-based and offline only support parallel==1!'
-            assert (
-                self.train_terminal_cfgs['vector_env_nums'] == 1
-            ), 'model-based and offline only support vector_env_nums==1!'
+        if self.train_terminal_cfgs is not None:
+            if self.algo_type in ['model-based', 'offline']:
+                assert (
+                    self.train_terminal_cfgs['vector_env_nums'] == 1
+                ), 'model-based and offline only support vector_env_nums==1!'
+            if self.algo_type in ['off-policy', 'model-based', 'offline']:
+                assert (
+                    self.train_terminal_cfgs['parallel'] == 1
+                ), 'off-policy, model-based and offline only support parallel==1!'

         cfgs = get_default_kwargs_yaml(self.algo, self.env_id, self.algo_type)

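The reworked guard checks the terminal configuration first, then applies the per-type restrictions: model-based and offline algorithms must still use a single vectorized environment, and off-policy algorithms now join them in requiring parallel == 1 (a single training process). That single-process guarantee is what makes the distributed gradient averaging removed below redundant. A minimal sketch of how the guard behaves, using a hypothetical terminal config (the names and values below are illustrative, not a shipped default):

```python
# Minimal sketch (not the omnisafe API): how the tightened guard behaves for a
# hypothetical terminal config. All names and values below are illustrative only.
algo_type = 'off-policy'  # e.g. DDPG / TD3 / SAC
train_terminal_cfgs = {'parallel': 2, 'vector_env_nums': 4}  # hypothetical user input

try:
    if train_terminal_cfgs is not None:
        if algo_type in ['model-based', 'offline']:
            assert train_terminal_cfgs['vector_env_nums'] == 1, (
                'model-based and offline only support vector_env_nums==1!'
            )
        if algo_type in ['off-policy', 'model-based', 'offline']:
            assert train_terminal_cfgs['parallel'] == 1, (
                'off-policy, model-based and offline only support parallel==1!'
            )
except AssertionError as exc:
    # parallel == 2 trips the new off-policy check; vector_env_nums == 4 alone would pass.
    print(f'rejected: {exc}')
```
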
14 changes: 6 additions & 8 deletions omnisafe/algorithms/off_policy/ddpg.py
@@ -29,7 +29,6 @@
 from omnisafe.common.buffer import VectorOffPolicyBuffer
 from omnisafe.common.logger import Logger
 from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
-from omnisafe.utils import distributed


 @registry.register
@@ -70,9 +69,9 @@ def _init_env(self) -> None:
             self._seed,
             self._cfgs,
         )
-        assert (self._cfgs.algo_cfgs.steps_per_epoch) % (
-            distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums
-        ) == 0, 'The number of steps per epoch is not divisible by the number of environments.'
+        assert (
+            self._cfgs.algo_cfgs.steps_per_epoch % self._cfgs.train_cfgs.vector_env_nums == 0
+        ), 'The number of steps per epoch is not divisible by the number of environments.'

         assert (
             int(self._cfgs.train_cfgs.total_steps) % self._cfgs.algo_cfgs.steps_per_epoch == 0
@@ -81,9 +80,10 @@ def _init_env(self) -> None:
             self._cfgs.train_cfgs.total_steps // self._cfgs.algo_cfgs.steps_per_epoch,
         )
         self._epoch: int = 0
-        self._steps_per_epoch: int = self._cfgs.algo_cfgs.steps_per_epoch // (
-            distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums
+        self._steps_per_epoch: int = (
+            self._cfgs.algo_cfgs.steps_per_epoch // self._cfgs.train_cfgs.vector_env_nums
         )
+
         self._update_cycle: int = self._cfgs.algo_cfgs.update_cycle
         assert (
             self._steps_per_epoch % self._update_cycle == 0
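Because parallel is now fixed at 1 for off-policy training, the per-epoch step budget is divided only by train_cfgs.vector_env_nums; the distributed.world_size() factor in the old expression always evaluated to 1 and has been dropped. A quick worked example with hypothetical numbers (not taken from any shipped config):

```python
# Worked example with hypothetical numbers (not taken from any shipped config).
steps_per_epoch = 2000   # algo_cfgs.steps_per_epoch
vector_env_nums = 4      # train_cfgs.vector_env_nums
update_cycle = 100       # algo_cfgs.update_cycle

# New divisibility check: only vector_env_nums matters, since parallel == 1.
assert steps_per_epoch % vector_env_nums == 0

# Each vectorized step advances all environments at once, so the per-epoch
# loop length shrinks accordingly.
steps_per_epoch_per_env = steps_per_epoch // vector_env_nums
assert steps_per_epoch_per_env % update_cycle == 0

print(steps_per_epoch_per_env)  # 500
```
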
@@ -420,7 +420,6 @@ def _update_reward_critic(
                 self._actor_critic.reward_critic.parameters(),
                 self._cfgs.algo_cfgs.max_grad_norm,
             )
-        distributed.avg_grads(self._actor_critic.reward_critic)
         self._actor_critic.reward_critic_optimizer.step()

     def _update_cost_critic(
@@ -463,7 +462,6 @@ def _update_cost_critic(
                 self._actor_critic.cost_critic.parameters(),
                 self._cfgs.algo_cfgs.max_grad_norm,
             )
-        distributed.avg_grads(self._actor_critic.cost_critic)
         self._actor_critic.cost_critic_optimizer.step()

         self._logger.store(
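The distributed.avg_grads calls deleted from DDPG (and from SAC and TD3 below) only matter when several worker processes hold different gradients that must be averaged before optimizer.step(); with the wrapper now asserting parallel == 1 for off-policy algorithms, the call was a guaranteed no-op. For context, a gradient average of this kind is typically an all-reduce over torch.distributed, roughly like the hedged sketch below (an illustration of the pattern, not the omnisafe implementation):

```python
# Hedged sketch of a distributed gradient average (assuming the common
# torch.distributed all-reduce pattern; not the omnisafe implementation).
import torch
import torch.distributed as dist


def avg_grads(module: torch.nn.Module) -> None:
    """Average .grad across processes; a no-op whenever only one process runs."""
    if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1:
        return  # single process (parallel == 1): nothing to synchronize
    world_size = dist.get_world_size()
    for param in module.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
```

On-policy algorithms, which the new assertion does not restrict to parallel == 1, are untouched by this commit.
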
2 changes: 0 additions & 2 deletions omnisafe/algorithms/off_policy/sac.py
@@ -21,7 +21,6 @@
 from omnisafe.algorithms import registry
 from omnisafe.algorithms.off_policy.ddpg import DDPG
 from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
-from omnisafe.utils import distributed


 @registry.register
@@ -142,7 +141,6 @@ def _update_reward_critic(
                 self._actor_critic.reward_critic.parameters(),
                 self._cfgs.algo_cfgs.max_grad_norm,
             )
-        distributed.avg_grads(self._actor_critic.reward_critic)
         self._actor_critic.reward_critic_optimizer.step()
         self._logger.store(
             {
2 changes: 0 additions & 2 deletions omnisafe/algorithms/off_policy/td3.py
@@ -21,7 +21,6 @@
 from omnisafe.algorithms import registry
 from omnisafe.algorithms.off_policy.ddpg import DDPG
 from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
-from omnisafe.utils import distributed


 @registry.register
@@ -108,7 +107,6 @@ def _update_reward_critic(
                 self._actor_critic.reward_critic.parameters(),
                 self._cfgs.algo_cfgs.max_grad_norm,
             )
-        distributed.avg_grads(self._actor_critic.reward_critic)
         self._actor_critic.reward_critic_optimizer.step()
         self._logger.store(
             {
