fix(on-policy): fix the second order algorithms performance (#147)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Gaiejj and pre-commit-ci[bot] authored Mar 13, 2023
1 parent 98b22b6 commit 02cd790
Showing 18 changed files with 707 additions and 154 deletions.
91 changes: 49 additions & 42 deletions omnisafe/algorithms/off_policy/README.md
You can set the main function of ``examples/benchmarks/experiment_grid.py`` as:

```python
from omnisafe.common.experiment_grid import ExperimentGrid

eg = ExperimentGrid(exp_name='Off-Policy-Velocity')

# set up the algorithms.
off_policy = ['DDPG', 'SAC', 'TD3']
eg.add('algo', off_policy)

# you can use wandb to monitor the experiment.
eg.add('logger_cfgs:use_wandb', [False])
# you can use tensorboard to monitor the experiment.
eg.add('logger_cfgs:use_tensorboard', [True])

# set up the environment.
eg.add('env_id', [
    'SafetyHopperVelocity-v4',
    'SafetyWalker2dVelocity-v4',
    'SafetySwimmerVelocity-v4',
    'SafetyAntVelocity-v4',
    'SafetyHalfCheetahVelocity-v4',
    'SafetyHumanoidVelocity-v4',
])

eg.add('seed', [0, 5, 10, 15, 20])
eg.run(train, num_pool=5)  # train is defined in the same script; see the sketch below
```
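``eg.run`` calls a ``train`` function defined earlier in the same script, once per grid cell. A minimal sketch of what it might look like, built on the public ``omnisafe.Agent`` API; the exact signature and return values here are assumptions, not shown in this commit:

```python
import omnisafe

def train(exp_id: str, algo: str, env_id: str, custom_cfgs: dict) -> tuple:
    """Hypothetical per-cell worker: build an agent from the grid keys and train it."""
    # custom_cfgs carries the colon-separated grid keys (e.g. 'logger_cfgs:use_wandb')
    # unflattened into a nested dict by the grid runner.
    agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
    reward, cost, ep_len = agent.learn()  # assumed to return the final metrics
    return reward, cost, ep_len
```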

After that, you can run ``examples/benchmarks/experiment_grid.py`` to launch the benchmark.
If you find that other hyperparameters perform better, please feel free to open an issue or pull request.

### DDPG

| Environment | Reward (OmniSafe) | Cost (OmniSafe) |
| :---------: | :-----------: | :-----------: |
| SafetyAntVelocity-v4 | **1243.15±619.17** | **289.80±161.52** |
| SafetyHalfCheetahVelocity-v4 | **9496.25±999.36** | **882.63±75.43** |
| SafetyHopperVelocity-v4 | **2369.89±932.39** | **673.36±278.33** |
| SafetyWalker2dVelocity-v4 | **1648.96±578.43** | **298.20±110.75** |
| SafetySwimmerVelocity-v4 | **101.63±57.55** | **507.16±152.13** |
| SafetyHumanoidVelocity-v4 | **3254.83±297.52** | **0.00±0.00** |

#### Hints for DDPG

DDPG has only one Q-network to estimate the Q-value, so it tends to overestimate. In our experiments, we found that ``gamma`` is important for DDPG: if ``gamma`` is too large, the overestimation is amplified through bootstrapping, so we set ``gamma`` to 0.99.
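If you want to sweep ``gamma`` yourself, it can be added to the experiment grid like any other key. A minimal sketch, assuming the discount factor is exposed under an ``algo_cfgs:gamma`` path; that path mirrors the ``logger_cfgs:...`` keys above and is an assumption, not something this commit shows:

```python
# hypothetical config path for the discount factor
eg.add('algo_cfgs:gamma', [0.99])
```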

| Environment | obs_normalize | rew_normalize | cost_normalize |
| :---------: | :-----------: | :-----------: | :-----------: |
| SafetyAntVelocity-v4 | False | False | False |
| SafetyHalfCheetahVelocity-v4 | False | False | False |
| SafetyHopperVelocity-v4 | False | False | False |
| SafetyHumanoidVelocity-v4 | **True** | **True** | False |
| SafetyWalker2dVelocity-v4 | False | False | False |
| SafetySwimmerVelocity-v4 | False | True | False |

Please note that ``cost_normalize`` has no effect for plain DDPG, but it does for DDPG-Lag.
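The flags in the table can be toggled per run through the same grid interface. A hedged sketch; as with ``gamma`` above, the ``algo_cfgs:...`` paths are assumptions rather than keys taken from this commit:

```python
# hypothetical config paths matching the table's column names
eg.add('algo_cfgs:obs_normalize', [True])
eg.add('algo_cfgs:rew_normalize', [True])
eg.add('algo_cfgs:cost_normalize', [False])
```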

### TD3

| Environment | Reward (OmniSafe) | Cost (OmniSafe) |
| :---------: | :-----------: | :-----------: |
| SafetyAntVelocity-v4 | **5107.66±619.95** | **978.33±4.41** |
| SafetyHalfCheetahVelocity-v4 | **8844.27±1812.2** | **981.43±1.08** |
| SafetyHopperVelocity-v4 | **3567.15±109.79** | **977.43±19.14** |
| SafetyWalker2dVelocity-v4 | **3962.93±355.76** | **904.83±21.69** |
| SafetySwimmerVelocity-v4 | **81.98±31.23** | **678.66±240.35** |
| SafetyHumanoidVelocity-v4 | **5245.66±674.81** | **0.00±0.00** |

#### Hints for TD3

| Environment | obs_normalize | rew_normalize | cost_normalize |
| :---------: | :-----------: | :-----------: | :-----------: |
| SafetyAntVelocity-v4 | **True** | False | False |
| SafetyHalfCheetahVelocity-v4 | False | False | False |
| SafetyHopperVelocity-v4 | False | False | False |
| SafetyHumanoidVelocity-v4 | False | False | False |
| SafetyWalker2dVelocity-v4 | False | False | False |
| SafetySwimmerVelocity-v4 | False | **True** | False |

Please note that ``cost_normalize`` has no effect for plain TD3, but it does for TD3-Lag.

### SAC

| Environment | Reward (OmniSafe) | Cost (OmniSafe) |
| :---------: | :-----------: | :-----------: |
| SafetyAntVelocity-v4 | **6061.45±129.37** | **929.53±7.10** |
| SafetyHalfCheetahVelocity-v4 | **10075.95±423.83** | **981.23±1.06** |
| SafetyHopperVelocity-v4 | **3386.41±89.95** | **992.76±0.16** |
| SafetyWalker2dVelocity-v4 | **4613.00±340.90** | **914.56±14.91** |
| SafetySwimmerVelocity-v4 | **44.80±3.65** | **376.50±152.89** |
| SafetyHumanoidVelocity-v4 | **5618.22±337.33** | **0.00±0.00** |

#### Hints for SAC

| Environment | obs_normalize | rew_normalize | cost_normalize |
| :---------: | :-----------: | :-----------: | :-----------: |
| SafetyAntVelocity-v4 | False | False | False |
| SafetyHalfCheetahVelocity-v4 | False | False | False |
| SafetyHopperVelocity-v4 | False | False | False |
| SafetyHumanoidVelocity-v4 | False | False | False |
| SafetyWalker2dVelocity-v4 | False | False | False |
| SafetySwimmerVelocity-v4 | False | **True** | False |

Please note that ``cost_normalize`` has no effect for plain SAC, but it does for SAC-Lag.