Commit 9e76d28

feat(off-policy): fix final_observation setting and support evaluation times configuration (#260)
Gaiejj authored Aug 1, 2023
1 parent 7fcfe78 commit 9e76d28
Showing 12 changed files with 25 additions and 6 deletions.
3 changes: 2 additions & 1 deletion omnisafe/adapter/offpolicy_adapter.py

@@ -137,7 +137,8 @@ def rollout( # pylint: disable=too-many-locals
             real_next_obs = next_obs.clone()
             for idx, done in enumerate(torch.logical_or(terminated, truncated)):
                 if done:
-                    real_next_obs[idx] = info['final_observation'][idx]
+                    if 'final_observation' in info:
+                        real_next_obs[idx] = info['final_observation'][idx]
                     self._log_metrics(logger, idx)
                     self._reset_log(idx)
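
Context for the guard, as a minimal sketch rather than the full rollout loop: Gymnasium-style vector environments autoreset when an episode ends, so next_obs already belongs to the new episode and the true terminal observation is reported separately in info['final_observation']. Not every environment or wrapper populates that key, which is why the unconditional lookup could raise a KeyError. The helper name below is hypothetical.

import torch

def patch_final_obs(
    next_obs: torch.Tensor,
    terminated: torch.Tensor,
    truncated: torch.Tensor,
    info: dict,
) -> torch.Tensor:
    # Start from the autoreset observations and overwrite only the slots
    # whose episodes actually ended.
    real_next_obs = next_obs.clone()
    for idx, done in enumerate(torch.logical_or(terminated, truncated)):
        # Guard against envs/wrappers that omit 'final_observation';
        # previously the lookup ran unconditionally and could KeyError.
        if done and 'final_observation' in info:
            real_next_obs[idx] = info['final_observation'][idx]
    return real_next_obs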
1 change: 0 additions & 1 deletion omnisafe/adapter/online_adapter.py

@@ -72,7 +72,6 @@ def __init__( # pylint: disable=too-many-arguments
        )

        self._env.set_seed(seed)
-       self._eval_env.set_seed(seed)

    def _wrapper(
        self,
9 changes: 5 additions & 4 deletions omnisafe/algorithms/off_policy/ddpg.py

@@ -200,9 +200,10 @@ def _init_log(self) -> None:
        self._logger.register_key('Metrics/EpCost', window_length=50)
        self._logger.register_key('Metrics/EpLen', window_length=50)

-       self._logger.register_key('Metrics/TestEpRet', window_length=50)
-       self._logger.register_key('Metrics/TestEpCost', window_length=50)
-       self._logger.register_key('Metrics/TestEpLen', window_length=50)
+       if self._cfgs.train_cfgs.eval_episodes > 0:
+           self._logger.register_key('Metrics/TestEpRet', window_length=50)
+           self._logger.register_key('Metrics/TestEpCost', window_length=50)
+           self._logger.register_key('Metrics/TestEpLen', window_length=50)

        self._logger.register_key('Train/Epoch')
        self._logger.register_key('Train/LR')

@@ -283,7 +284,7 @@ def learn(self) -> tuple[float, float, float]:

            eval_start = time.time()
            self._env.eval_policy(
-               episode=1,
+               episode=self._cfgs.train_cfgs.eval_episodes,
                agent=self._actor_critic,
                logger=self._logger,
            )
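
With the episode count now read from configuration, the number of in-training evaluation episodes can be tuned per run. A usage sketch, assuming the omnisafe.Agent entry point and a Safety-Gymnasium task id, neither of which is shown in this diff:

import omnisafe

custom_cfgs = {
    'train_cfgs': {
        # 0 runs no evaluation episodes and leaves the Test* metric keys
        # unregistered; larger values average Test* metrics over more episodes.
        'eval_episodes': 5,
    },
}
agent = omnisafe.Agent('DDPG', 'SafetyPointGoal1-v0', custom_cfgs=custom_cfgs)
agent.learn()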
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/DDPG.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/DDPGLag.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/DDPGPID.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/SAC.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/SACLag.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/SACPID.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/TD3.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/TD3Lag.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
2 changes: 2 additions & 0 deletions omnisafe/configs/off-policy/TD3PID.yaml

@@ -28,6 +28,8 @@ defaults:
   parallel: 1
   # total number of steps to train
   total_steps: 1000000
+  # number of evaluation episodes
+  eval_episodes: 1
   # algorithm configurations
   algo_cfgs:
     # number of steps to update the policy
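
The same two lines land in all nine off-policy configs. A quick consistency check, sketched under the assumption (consistent with the ddpg.py hunk above) that the new key sits under defaults -> train_cfgs, and that PyYAML is installed and the script runs from the repo root:

import yaml
from pathlib import Path

# Iterate over the off-policy configs touched by this commit.
for cfg_path in sorted(Path('omnisafe/configs/off-policy').glob('*.yaml')):
    with cfg_path.open(encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    # Every config should now default to one evaluation episode.
    assert cfg['defaults']['train_cfgs']['eval_episodes'] == 1, cfg_path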
