diff --git a/README.md b/README.md index 2b9396bad..7b7bdab87 100644 --- a/README.md +++ b/README.md @@ -128,8 +128,11 @@ OmniSafe requires Python 3.8+ and PyTorch 1.10+. ### Install from source ```bash +# Clone the repo git clone https://github.com/PKU-MARL/omnisafe cd omnisafe + +# Create a conda environment conda create -n omnisafe python=3.8 conda activate omnisafe @@ -141,7 +144,7 @@ pip install -e . ```bash cd examples -python train_policy.py --env-id SafetyPointGoal1-v0 --algo PPOLag --parallel 1 +python train_policy.py --algo PPOLag --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1 ``` **algo:** @@ -208,29 +211,36 @@ More information about environments, please refer to [Safety Gymnasium](https:// ```python import omnisafe -env = 'SafetyPointGoal1-v0' -agent = omnisafe.Agent('PPOLag', env) +env_id = 'SafetyPointGoal1-v0' +custom_cfgs = { + 'train_cfgs': { + 'total_steps': 1024000, + 'vector_env_nums': 1, + 'parallel': 1, + }, + 'algo_cfgs': { + 'update_cycle': 2048, + 'update_iters': 1, + }, + 'logger_cfgs': { + 'use_wandb': False, + }, +} + +agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs) agent.learn() ``` -### 2. Run Agent from custom config dict - -```python -import omnisafe - -env = 'SafetyPointGoal1-v0' +### 3. Run Agent from custom terminal config -custom_dict = {'epochs': 1, 'log_dir': './runs'} -agent = omnisafe.Agent('PPOLag', env, custom_cfgs=custom_dict) -agent.learn() -``` +You can also run the agent from a custom terminal config. You can set any config in the corresponding yaml file. -### 3. 
Run Agent from custom terminal config +For example, you can run the `PPOLag` agent on the `SafetyPointGoal1-v0` environment with `total_steps=1024000`, `vector_env_nums=1` and `parallel=1` by: ```bash cd examples -python train_policy.py --env-id SafetyPointGoal1-v0 --algo PPOLag --parallel 1 +python train_policy.py --algo PPOLag --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1 ``` ### 4. Evalutate Saved Policy diff --git a/examples/train_from_custom_dict.py b/examples/train_from_custom_dict.py index d1d6f770c..816c35ffb 100644 --- a/examples/train_from_custom_dict.py +++ b/examples/train_from_custom_dict.py @@ -14,44 +14,24 @@ # ============================================================================== """Example of training a policy from custom dict with OmniSafe.""" -import argparse - import omnisafe -parser = argparse.ArgumentParser() -env_id = 'SafetyHumanoidVelocity-v4' -parser.add_argument( - '--parallel', - default=1, - type=int, - metavar='N', - help='Number of paralleled progress for calculations.', -) +env_id = 'SafetyPointGoal1-v0' custom_cfgs = { 'train_cfgs': { - 'total_steps': 1000, + 'total_steps': 1024000, + 'vector_env_nums': 1, + 'parallel': 1, }, 'algo_cfgs': { - 'update_cycle': 1000, + 'update_cycle': 2048, 'update_iters': 1, }, 'logger_cfgs': { 'use_wandb': False, }, - 'env_cfgs': { - 'vector_env_nums': 1, - }, } -args, _ = parser.parse_known_args() -agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs, parallel=args.parallel) -agent.learn() -# obs = env.reset() -# for i in range(1000): -# action, _states = agent.predict(obs, deterministic=True) -# obs, reward, done, info = env.step(action) -# env.render() -# if done: -# obs = env.reset() -# env.close() +agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs) +agent.learn() diff --git a/examples/train_from_yaml.py b/examples/train_from_yaml.py index 7c856cba9..03641150e 100644 --- a/examples/train_from_yaml.py
+++ b/examples/train_from_yaml.py @@ -16,16 +16,7 @@ import omnisafe -env = omnisafe.Env('SafetyPointGoal1-v0') +env_id = 'SafetyPointGoal1-v0' -agent = omnisafe.Agent('PPOLag', env) +agent = omnisafe.Agent('PPOLag', env_id) agent.learn() - -# obs = env.reset() -# for i in range(1000): -# action, _states = agent.predict(obs, deterministic=True) -# obs, reward, cost, done, info = env.step(action) -# env.render() -# if done: -# obs = env.reset() -# env.close()