Skip to content

Commit

Permalink
Merge branch 'PKU-Alignment:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaiejj authored Feb 17, 2024
2 parents d3f096c + 51a2692 commit 4e54ba3
Show file tree
Hide file tree
Showing 30 changed files with 493 additions and 177 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:
default_stages: [commit, push, manual]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand All @@ -29,7 +29,7 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.284
rev: v0.0.292
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -38,11 +38,11 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black-jupyter
- repo: https://github.com/asottile/pyupgrade
rev: v3.10.1
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py38-plus] # sync with requires-python
Expand All @@ -63,7 +63,7 @@ repos:
^docs/source/conf.py$
)
- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: v2.2.6
hooks:
- id: codespell
additional_dependencies: [".[toml]"]
Expand Down
23 changes: 0 additions & 23 deletions CODEOWNERS

This file was deleted.

6 changes: 3 additions & 3 deletions benchmarks/offline/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ pip install safety_gymnasium
## Training agents used to generate data

```bash
omnisafe train --env-id SafetyAntVelocity-v1 --algo PPO
omnisafe train --env-id SafetyAntVelocity-v1 --algo PPOLag
```

## Collect offline data

The `PATH_TO_AGENT` is the path of the directory containing the `torch_save`.

```python
from omnisafe.common.offline.data_collector import OfflineDataCollector

Expand All @@ -40,8 +41,7 @@ from omnisafe.common.offline.data_collector import OfflineDataCollector
env_name = 'SafetyAntVelocity-v1'
size = 1_000_000
agents = [
('./runs/PPO', 'epoch-500', 500_000),
('./runs/PPOLag', 'epoch-500', 500_000),
('PATH_TO_AGENT', 'epoch-500.pt', 1_000_000),
]
save_dir = './data'

Expand Down
172 changes: 108 additions & 64 deletions benchmarks/on-policy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1468,22 +1468,22 @@ class="math inline">±</span> 30.93</td>
<tr class="even">
<td style="text-align: left;"><span
class="smallcaps">SafetyCarGoal1-v0</span></td>
<td style="text-align: center;">-0.65 <span class="math inline">±</span>
2.89</td>
<td style="text-align: center;">22.90 <span class="math inline">±</span>
16.85</td>
<td style="text-align: center;">1.89 <span class="math inline">±</span>
3.52</td>
<td style="text-align: center;">4.86 <span class="math inline">±</span>
3.11</td>
<td style="text-align: center;">0.81 <span class="math inline">±</span>
0.41</td>
<td style="text-align: center;">17.18 <span class="math inline">±</span>
14.30</td>
<td style="text-align: center;">0.24 <span class="math inline">±</span>
0.41</td>
<td style="text-align: center;">0.90 <span class="math inline">±</span>
1.18</td>
<td style="text-align: center;">7.12 <span class="math inline">±</span>
5.41</td>
<td style="text-align: center;">21.68 <span class="math inline">±</span>
29.11</td>
<td style="text-align: center;">16.67 <span class="math inline">±</span>
10.57</td>
<td style="text-align: center;">23.58 <span class="math inline">±</span>
26.39</td>
<td style="text-align: center;">8.45 <span class="math inline">±</span>
7.16</td>
<td style="text-align: center;">18.98 <span class="math inline">±</span>
25.63</td>
<td style="text-align: center;">15.08 <span class="math inline">±</span>
13.41</td>
<td style="text-align: center;">23.22 <span class="math inline">±</span>
19.80</td>
</tr>
<tr class="odd">
<td style="text-align: left;"><span
Expand All @@ -1508,22 +1508,22 @@ class="smallcaps">SafetyCarButton1-v0</span></td>
<tr class="even">
<td style="text-align: left;"><span
class="smallcaps">SafetyCarGoal2-v0</span></td>
<td style="text-align: center;">-0.87 <span class="math inline">±</span>
0.79</td>
<td style="text-align: center;">6.13 <span class="math inline">±</span>
4.51</td>
<td style="text-align: center;">-1.03 <span class="math inline">±</span>
1.46</td>
<td style="text-align: center;">18.07 <span class="math inline">±</span>
11.62</td>
<td style="text-align: center;">-0.96 <span class="math inline">±</span>
1.10</td>
<td style="text-align: center;">3.00 <span class="math inline">±</span>
0.83</td>
<td style="text-align: center;">-0.36 <span class="math inline">±</span>
0.09</td>
<td style="text-align: center;">9.83 <span class="math inline">±</span>
13.91</td>
<td style="text-align: center;">0.90 <span class="math inline">±</span>
1.20</td>
<td style="text-align: center;">19.98 <span class="math inline">±</span>
10.12</td>
<td style="text-align: center;">1.76 <span class="math inline">±</span>
5.20</td>
<td style="text-align: center;">31.50 <span class="math inline">±</span>
45.50</td>
<td style="text-align: center;">1.02 <span class="math inline">±</span>
1.41</td>
<td style="text-align: center;">27.32 <span class="math inline">±</span>
60.12</td>
<td style="text-align: center;">0.93 <span class="math inline">±</span>
2.21</td>
<td style="text-align: center;">26.66 <span class="math inline">±</span>
60.07</td>
</tr>
<tr class="odd">
<td style="text-align: left;"><span
Expand All @@ -1548,22 +1548,22 @@ class="smallcaps">SafetyCarButton2-v0</span></td>
<tr class="even">
<td style="text-align: left;"><span
class="smallcaps">SafetyPointGoal1-v0</span></td>
<td style="text-align: center;">1.99 <span class="math inline">±</span>
2.87</td>
<td style="text-align: center;">7.80 <span class="math inline">±</span>
2.78</td>
<td style="text-align: center;">1.02 <span class="math inline">±</span>
0.80</td>
<td style="text-align: center;">7.46 <span class="math inline">±</span>
5.26</td>
<td style="text-align: center;">1.69 <span class="math inline">±</span>
3.25</td>
<td style="text-align: center;">5.34 <span class="math inline">±</span>
10.33</td>
<td style="text-align: center;">1.38 <span class="math inline">±</span>
1.91</td>
<td style="text-align: center;">1.34 <span class="math inline">±</span>
1.82</td>
<td style="text-align: center;">7.06 <span class="math inline">±</span>
5.85</td>
<td style="text-align: center;">20.04 <span class="math inline">±</span>
21.91</td>
<td style="text-align: center;">16.18 <span class="math inline">±</span>
9.55</td>
<td style="text-align: center;">29.94 <span class="math inline">±</span>
26.68</td>
<td style="text-align: center;">8.30 <span class="math inline">±</span>
6.03</td>
<td style="text-align: center;">25.32 <span class="math inline">±</span>
31.91</td>
<td style="text-align: center;">11.64 <span class="math inline">±</span>
8.46</td>
<td style="text-align: center;">30.00 <span class="math inline">±</span>
27.67</td>
</tr>
<tr class="odd">
<td style="text-align: left;"><span
Expand All @@ -1588,22 +1588,22 @@ class="smallcaps">SafetyPointButton1-v0</span></td>
<tr class="even">
<td style="text-align: left;"><span
class="smallcaps">SafetyPointGoal2-v0</span></td>
<td style="text-align: center;">-1.85 <span class="math inline">±</span>
0.99</td>
<td style="text-align: center;">21.77 <span class="math inline">±</span>
13.56</td>
<td style="text-align: center;">-1.38 <span class="math inline">±</span>
1.16</td>
<td style="text-align: center;">7.87 <span class="math inline">±</span>
2.02</td>
<td style="text-align: center;">-1.13 <span class="math inline">±</span>
0.39</td>
<td style="text-align: center;">7.03 <span class="math inline">±</span>
4.21</td>
<td style="text-align: center;">-0.54 <span class="math inline">±</span>
0.18</td>
<td style="text-align: center;">26.57 <span class="math inline">±</span>
19.13</td>
<td style="text-align: center;">0.84 <span class="math inline">±</span>
2.93</td>
<td style="text-align: center;">14.06 <span class="math inline">±</span>
30.21</td>
<td style="text-align: center;">1.64 <span class="math inline">±</span>
4.02</td>
<td style="text-align: center;">19.00 <span class="math inline">±</span>
34.69</td>
<td style="text-align: center;">0.56 <span class="math inline">±</span>
2.52</td>
<td style="text-align: center;">12.36 <span class="math inline">±</span>
43.39</td>
<td style="text-align: center;">1.55 <span class="math inline">±</span>
4.68</td>
<td style="text-align: center;">14.90 <span class="math inline">±</span>
27.82</td>
</tr>
<tr class="odd">
<td style="text-align: left;"><span
Expand Down Expand Up @@ -2573,6 +2573,28 @@ class="smallcaps">SafetyPointButton2-v0</span></td>
</tr>
</table>
<table>
<tr>
<td style="text-align:center">
<img style="border-radius: 0.3125em; box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" src="https://github.com/Gaiejj/omnisafe_benchmarks_cruve/blob/main/on-policy/benchmarks/saute_carcircle1_1e7.png">
<br>
<div style="color:orange; border-bottom: 1px solid #d9d9d9; display: inline-block; color: #999; padding: 2px;">
SafetyCarCircle1-v0
</div>
</td>
</tr>
</table>
<table>
<tr>
<td style="text-align:center">
<img style="border-radius: 0.3125em; box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" src="https://github.com/Gaiejj/omnisafe_benchmarks_cruve/blob/main/on-policy/benchmarks/saute_carcircle2_1e7.png">
<br>
<div style="color:orange; border-bottom: 1px solid #d9d9d9; display: inline-block; color: #999; padding: 2px;">
SafetyCarCircle2-v0
</div>
</td>
</tr>
</table>
<table>
<tr>
<td style="text-align:center">
<img style="border-radius: 0.3125em; box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" src="https://github.com/Gaiejj/omnisafe_benchmarks_cruve/blob/main/on-policy/benchmarks/saute_cargoal1_1e7.png">
Expand Down Expand Up @@ -2617,6 +2639,28 @@ class="smallcaps">SafetyPointButton2-v0</span></td>
</tr>
</table>
<table>
<tr>
<td style="text-align:center">
<img style="border-radius: 0.3125em; box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" src="https://github.com/Gaiejj/omnisafe_benchmarks_cruve/blob/main/on-policy/benchmarks/saute_pointcircle1_1e7.png">
<br>
<div style="color:orange; border-bottom: 1px solid #d9d9d9; display: inline-block; color: #999; padding: 2px;">
SafetyPointCircle1-v0
</div>
</td>
</tr>
</table>
<table>
<tr>
<td style="text-align:center">
<img style="border-radius: 0.3125em; box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" src="https://github.com/Gaiejj/omnisafe_benchmarks_cruve/blob/main/on-policy/benchmarks/saute_pointcircle2_1e7.png">
<br>
<div style="color:orange; border-bottom: 1px solid #d9d9d9; display: inline-block; color: #999; padding: 2px;">
SafetyPointCircle2-v0
</div>
</td>
</tr>
</table>
<table>
<tr>
<td style="text-align:center">
<img style="border-radius: 0.3125em; box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" src="https://github.com/Gaiejj/omnisafe_benchmarks_cruve/blob/main/on-policy/benchmarks/saute_pointgoal1_1e7.png">
Expand Down
2 changes: 1 addition & 1 deletion docs/source/baserl/trpo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ where
Finally, we obtain :math:`\beta=\sqrt{2 \delta / s^T H s}`.

.. hint::
The term :math:`s^THs` is an intermediate result produced by the conjugate gradient algorithm.
The term :math:`s^T H s` is an intermediate result produced by the conjugate gradient algorithm.

To meet the constraints, TRPO uses line search algorithm to compute the final step length.
Detailedly, TRPO performs the line search on the objective :math:`L_{{\boldsymbol{\theta}}_{\text {old }}}({\boldsymbol{\theta}})-\mathcal{X}\left[\bar{D}_{\text {KL }}\left({\boldsymbol{\theta}}_{\text {old }}, {\boldsymbol{\theta}}\right) \leq \delta\right]`, where :math:`\mathcal{X}[\ldots]` equals to :math:`0`,
Expand Down
4 changes: 2 additions & 2 deletions docs/source/saferl/lag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,10 @@ Policy update
.. attention::
:class: warning

In practice, we often need to manually set the initial value of as well as the learning rate.
In practice, we often need to manually set the initial value of :math:`\lambda` as well as the learning rate :math:`\eta_\lambda`.
Unfortunately, Lagrange algorithms are algorithms that **are sensitive to hyperparameter selection**.

- If the initial value of :math:`\lambda` or learning rate is chosen to be large,
- If the initial value of :math:`\lambda` or learning rate :math:`\eta_\lambda` is chosen to be large,
the agent may suffer from a low reward.
- Else, it may violate the constraints.

Expand Down
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,4 @@ DynamicsValMseLoss
UpdateActorCritic
UpdateDynamics
mathbb
meger
18 changes: 9 additions & 9 deletions examples/collect_offline_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
# also, please make sure you have run:
# python train_policy.py --algo PPO --env ENVID
# where ENVID is the environment from which you want to collect data.
# The `PATH_TO_AGENT` is the directory path containing the `torch_save`.

ENV_NAME = 'SafetyPointCircle1-v0'
SIZE = 2_000_000
AGENTS = [
('./runs/PPO', 'epoch-500', 1_000_000),
('./runs/PPOLag', 'epoch-500', 1_000_000),
env_name = 'SafetyAntVelocity-v1'
size = 1_000_000
agents = [
('PATH_TO_AGENT', 'epoch-500.pt', 1_000_000),
]
SAVE_DIR = './data'
save_dir = './data'

if __name__ == '__main__':
col = OfflineDataCollector(SIZE, ENV_NAME)
for agent, model_name, num in AGENTS:
col = OfflineDataCollector(size, env_name)
for agent, model_name, num in agents:
col.register_agent(agent, model_name, num)
col.collect(SAVE_DIR)
col.collect(save_dir)
1 change: 1 addition & 0 deletions omnisafe/adapter/simmer_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,4 @@ def control_budget(self, ep_costs: torch.Tensor) -> None:
safety_budget=self._safety_budget.cpu(),
observation=ep_costs.cpu(),
).to(self._device)
self._rel_safety_budget = (self._safety_budget / self._upper_budget).to(self._device)
15 changes: 9 additions & 6 deletions omnisafe/algorithms/algo_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,21 @@ def _init_checks(self) -> None:
def _init_algo(self) -> None:
"""Initialize the algorithm."""
check_all_configs(self.cfgs, self.algo_type)
device = self.cfgs.train_cfgs.device
if device == 'cpu':
torch.set_num_threads(self.cfgs.train_cfgs.torch_threads)
else:
torch.set_num_threads(1)
torch.cuda.set_device(self.cfgs.train_cfgs.device)
if distributed.fork(
self.cfgs.train_cfgs.parallel,
device=self.cfgs.train_cfgs.device,
):
# re-launches the current script with workers linked by MPI
sys.exit()
if self.cfgs.train_cfgs.device == 'cpu':
torch.set_num_threads(self.cfgs.train_cfgs.torch_threads)
else:
if self.cfgs.train_cfgs.parallel > 1 and os.getenv('MASTER_ADDR') is not None:
ddp_local_rank = int(os.environ['LOCAL_RANK'])
self.cfgs.train_cfgs.device = f'cuda:{ddp_local_rank}'
torch.set_num_threads(1)
torch.cuda.set_device(self.cfgs.train_cfgs.device)
os.environ['OMNISAFE_DEVICE'] = self.cfgs.train_cfgs.device
self.agent: BaseAlgo = registry.get(self.algo)(
env_id=self.env_id,
cfgs=self.cfgs,
Expand Down
Loading

0 comments on commit 4e54ba3

Please sign in to comment.