Skip to content

Commit

Permalink
docs: update README.md (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaiejj authored Mar 14, 2023
1 parent 8058f78 commit b5f785e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 53 deletions.
40 changes: 25 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,11 @@ OmniSafe requires Python 3.8+ and PyTorch 1.10+.
### Install from source

```bash
# Clone the repo
git clone https://github.com/PKU-MARL/omnisafe
cd omnisafe

# Create a conda environment
conda create -n omnisafe python=3.8
conda activate omnisafe

Expand All @@ -141,7 +144,7 @@ pip install -e .

```bash
cd examples
python train_policy.py --env-id SafetyPointGoal1-v0 --algo PPOLag --parallel 1
python train_policy.py --algo PPOLag --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1
```

**algo:**
Expand Down Expand Up @@ -208,29 +211,36 @@ More information about environments, please refer to [Safety Gymnasium](https://
```python
import omnisafe

env = 'SafetyPointGoal1-v0'

agent = omnisafe.Agent('PPOLag', env)
env_id = 'SafetyPointGoal1-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 1024000,
'vector_env_nums': 1,
'--parallel': 1,
},
'algo_cfgs': {
'update_cycle': 2048,
'update_iters': 1,
},
'logger_cfgs': {
'use_wandb': False,
},
}

agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs)
agent.learn()
```

### 2. Run Agent from custom config dict

```python
import omnisafe

env = 'SafetyPointGoal1-v0'
### 3. Run Agent from custom terminal config

custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('PPOLag', env, custom_cfgs=custom_dict)
agent.learn()
```
You can also run agent from custom terminal config. You can set any config in corresponding yaml file.

### 3. Run Agent from custom terminal config
For example, you can run `PPOLag` agent on `SafetyPointGoal1-v0` environment with `total_steps=1024000`, `vector_env_nums=1` and `parallel=1` by:

```bash
cd examples
python train_policy.py --env-id SafetyPointGoal1-v0 --algo PPOLag --parallel 1
python train_policy.py --algo PPOLag --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1
```

### 4. Evalutate Saved Policy
Expand Down
34 changes: 7 additions & 27 deletions examples/train_from_custom_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,24 @@
# ==============================================================================
"""Example of training a policy from custom dict with OmniSafe."""

import argparse

import omnisafe


parser = argparse.ArgumentParser()
env_id = 'SafetyHumanoidVelocity-v4'
parser.add_argument(
'--parallel',
default=1,
type=int,
metavar='N',
help='Number of paralleled progress for calculations.',
)
env_id = 'SafetyPointGoal1-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 1000,
'total_steps': 1024000,
'vector_env_nums': 1,
'--parallel': 1,
},
'algo_cfgs': {
'update_cycle': 1000,
'update_cycle': 2048,
'update_iters': 1,
},
'logger_cfgs': {
'use_wandb': False,
},
'env_cfgs': {
'vector_env_nums': 1,
},
}
args, _ = parser.parse_known_args()
agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs, parallel=args.parallel)
agent.learn()

# obs = env.reset()
# for i in range(1000):
# action, _states = agent.predict(obs, deterministic=True)
# obs, reward, done, info = env.step(action)
# env.render()
# if done:
# obs = env.reset()
# env.close()
agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs)
agent.learn()
13 changes: 2 additions & 11 deletions examples/train_from_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,7 @@
import omnisafe


env = omnisafe.Env('SafetyPointGoal1-v0')
env_id = 'SafetyPointGoal1-v0'

agent = omnisafe.Agent('PPOLag', env)
agent = omnisafe.Agent('PPOLag', env_id)
agent.learn()

# obs = env.reset()
# for i in range(1000):
# action, _states = agent.predict(obs, deterministic=True)
# obs, reward, cost, done, info = env.step(action)
# env.render()
# if done:
# obs = env.reset()
# env.close()

0 comments on commit b5f785e

Please sign in to comment.