Skip to content

Commit

Permalink
fix(onpolicy_adapter): fix the calculation of last state value (#164)
Browse files Browse the repository at this point in the history
Co-authored-by: borong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored and zmsn-2077 committed Mar 26, 2023
1 parent 96e45f3 commit 7248752
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/train_from_custom_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
'train_cfgs': {
'total_steps': 1024000,
'vector_env_nums': 1,
- '--parallel': 1,
+ 'parallel': 1,
},
'algo_cfgs': {
'update_cycle': 2048,
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/adapter/onpolicy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def roll_out( # pylint: disable=too-many-locals
_, last_value_r, last_value_c, _ = agent.step(obs[idx])
if time_out:
_, last_value_r, last_value_c, _ = agent.step(
- info[idx]['final observation']
+ info['final_observation'][idx]
)
last_value_r = last_value_r.unsqueeze(0)
last_value_c = last_value_c.unsqueeze(0)
Expand Down
12 changes: 12 additions & 0 deletions omnisafe/envs/safety_gymnasium_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from typing import Any, Dict, Optional, Tuple

import numpy as np
import safety_gymnasium
import torch

Expand Down Expand Up @@ -96,6 +97,17 @@ def step(
lambda x: torch.as_tensor(x, dtype=torch.float32),
(obs, reward, cost, terminated, truncated),
)
if 'final_observation' in info:
info['final_observation'] = np.array(
[
array if array is not None else np.zeros(obs.shape[-1])
for array in info['final_observation']
]
)
info['final_observation'] = torch.as_tensor(
info['final_observation'], dtype=torch.float32
)

return obs, reward, cost, terminated, truncated, info

def reset(self, seed: Optional[int] = None) -> Tuple[torch.Tensor, Dict]:
Expand Down
9 changes: 9 additions & 0 deletions omnisafe/envs/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ def step(
self, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
obs, reward, cost, terminated, truncated, info = super().step(action)
if 'final_observation' in info:
if self.num_envs > 1:
final_obs_slice = info['_final_observation']
else:
final_obs_slice = slice(None)
info['original_final_observation'] = info['final_observation']
info['final_observation'][final_obs_slice] = self._obs_normalizer.normalize(
info['final_observation'][final_obs_slice]
)
info['original_obs'] = obs
obs = self._obs_normalizer.normalize(obs)
return obs, reward, cost, terminated, truncated, info
Expand Down
4 changes: 2 additions & 2 deletions tests/saved_source/benchmark_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ train_cfgs:vector_env_nums:
train_cfgs:torch_threads:
[1]
train_cfgs:total_steps:
1024
2048
algo_cfgs:update_cycle:
512
2048
seed:
[0]
2 changes: 1 addition & 1 deletion tests/saved_source/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ train_cfgs:
vector_env_nums: 1
algo_cfgs:
update_cycle:
512
1024
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_train():
'--custom-cfgs',
'algo_cfgs:update_cycle',
'--custom-cfgs',
- '512',
+ '1024',
],
)
assert result.exit_code == 0
Expand Down

0 comments on commit 7248752

Please sign in to comment.