Skip to content

Commit

Permalink
fix(onpolicy_adapter): fix the calculation of last state value (#162)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: borong <[email protected]>
  • Loading branch information
3 people authored Mar 19, 2023
1 parent 3c2daf0 commit 973b363
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
10 changes: 7 additions & 3 deletions omnisafe/adapter/onpolicy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,19 @@ def roll_out( # pylint: disable=too-many-locals
 epoch_end = step >= steps_per_epoch - 1
 for idx, (done, time_out) in enumerate(zip(terminated, truncated)):
     if epoch_end or done or time_out:
-        if (epoch_end or time_out) and not done:
+        if not done:
             if epoch_end:
                 logger.log(
                     f'Warning: trajectory cut off when rollout by epoch at {self._ep_len[idx]} steps.'
                 )
-            _, last_value_r, last_value_c, _ = agent.step(obs[idx])
+                _, last_value_r, last_value_c, _ = agent.step(obs[idx])
+            if time_out:
+                _, last_value_r, last_value_c, _ = agent.step(
+                    info[idx]['final observation']
+                )
             last_value_r = last_value_r.unsqueeze(0)
             last_value_c = last_value_c.unsqueeze(0)
-        elif done:
+        else:
             last_value_r = torch.zeros(1)
             last_value_c = torch.zeros(1)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typer.testing import CliRunner

-from omnisafe import app
+from omnisafe.utils.command_app import app


runner = CliRunner()
Expand Down Expand Up @@ -76,6 +76,7 @@ def test_eval():
'1',
'--height',
'1',
+'--no-render',
],
)
assert result.exit_code == 0
2 changes: 1 addition & 1 deletion tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_off_policy(algo):
},
'algo_cfgs': {
'update_cycle': 1024,
-'step_per_sample': 1024,
+'steps_per_sample': 1024,
'update_iters': 1,
'start_learning_steps': 0,
},
Expand Down
14 changes: 6 additions & 8 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def train(
USE_REDIRECTION = True
if USE_REDIRECTION:
if not os.path.exists(custom_cfgs['data_dir']):
-os.makedirs(custom_cfgs['data_dir'])
+os.makedirs(custom_cfgs['data_dir'], exist_ok=True)
sys.stdout = open(f'{custom_cfgs["data_dir"]}terminal.log', 'w', encoding='utf-8')
sys.stderr = open(f'{custom_cfgs["data_dir"]}error.log', 'w', encoding='utf-8')
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
Expand All @@ -128,20 +128,18 @@ def train(


 def test_train(
-    exp_name='Safety_Gymnasium_Goal',
+    exp_name='make_test_exp_grid',
     algo='CPO',
     env_id='SafetyHalfCheetahVelocity-v4',
-    epochs=1,
-    steps_per_epoch=1000,
-    num_envs=1,
 ):
     """Test train."""
     eg = ExperimentGrid(exp_name=exp_name)
     eg.add('algo', [algo])
     eg.add('env_id', [env_id])
-    eg.add('epochs', [epochs])
-    eg.add('steps_per_epoch', [steps_per_epoch])
-    eg.add('env_cfgs', [{'num_envs': num_envs}])
+    eg.add('logger_cfgs:use_wandb', [False])
+    eg.add('algo_cfgs:update_cycle', [512])
+    eg.add('train_cfgs:total_steps', [1024, 2048])
+    eg.add('train_cfgs:vector_env_nums', [1])
     eg.run(train, num_pool=1, is_test=True)


Expand Down

0 comments on commit 973b363

Please sign in to comment.