feat: support isaacgym interface
Gaiejj committed Apr 21, 2024
1 parent 94fba98 commit dcdc3bf
Showing 11 changed files with 577 additions and 66 deletions.
6 changes: 6 additions & 0 deletions omnisafe/__init__.py
@@ -14,6 +14,12 @@
 # ==============================================================================
 """OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning."""
 
+from contextlib import suppress
+
+
+with suppress(ImportError):
+    from isaacgym import gymutil
+
 from omnisafe import algorithms
 from omnisafe.algorithms import ALGORITHMS
 from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent
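For context: Isaac Gym requires `isaacgym` to be imported before `torch`, and the `omnisafe` subpackages imported just below pull in `torch`, so the guarded import has to sit at the very top of the package. A minimal sketch of the resulting import order (not part of the commit; it behaves the same whether or not Isaac Gym is installed):

import omnisafe  # runs the guarded `from isaacgym import gymutil` first, before torch loads
import torch     # safe to import afterwards

print(type(omnisafe.Agent))  # the public entry point is available either way
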
47 changes: 35 additions & 12 deletions omnisafe/adapter/online_adapter.py
@@ -69,14 +69,17 @@ def __init__( # pylint: disable=too-many-arguments
         env_cfgs = self._cfgs.env_cfgs.todict()
 
         self._env: CMDP = make(env_id, num_envs=num_envs, device=self._device, **env_cfgs)
-        self._eval_env: CMDP = make(env_id, num_envs=1, device=self._device, **env_cfgs)
 
         self._wrapper(
             obs_normalize=cfgs.algo_cfgs.obs_normalize,
             reward_normalize=cfgs.algo_cfgs.reward_normalize,
             cost_normalize=cfgs.algo_cfgs.cost_normalize,
         )
 
+        self._eval_env: CMDP | None = None
+        if self._env.need_evaluation:
+            self._eval_env: CMDP = make(env_id, num_envs=1, device=self._device, **env_cfgs)
+            self._wrapper_eval(obs_normalize=cfgs.algo_cfgs.obs_normalize)
 
         self._env.set_seed(seed)
 
     def _wrapper(
@@ -116,33 +119,53 @@ def _wrapper(
"""
if self._env.need_time_limit_wrapper:
assert (
self._env.max_episode_steps and self._eval_env.max_episode_steps
self._env.max_episode_steps
), 'You must define max_episode_steps as an integer\
or cancel the use of the time_limit wrapper.'
\nor cancel the use of the time_limit wrapper.'
self._env = TimeLimit(
self._env,
time_limit=self._env.max_episode_steps,
device=self._device,
)
self._eval_env = TimeLimit(
self._eval_env,
time_limit=self._eval_env.max_episode_steps,
device=self._device,
)
if self._env.need_auto_reset_wrapper:
self._env = AutoReset(self._env, device=self._device)
self._eval_env = AutoReset(self._eval_env, device=self._device)
if obs_normalize:
self._env = ObsNormalize(self._env, device=self._device)
self._eval_env = ObsNormalize(self._eval_env, device=self._device)
if reward_normalize:
self._env = RewardNormalize(self._env, device=self._device)
if cost_normalize:
self._env = CostNormalize(self._env, device=self._device)
self._env = ActionScale(self._env, low=-1.0, high=1.0, device=self._device)
self._eval_env = ActionScale(self._eval_env, low=-1.0, high=1.0, device=self._device)
if self._env.num_envs == 1:
self._env = Unsqueeze(self._env, device=self._device)

def _wrapper_eval(
self,
obs_normalize: bool = True,
) -> None:
"""Wrapper the environment for evaluation.
Args:
obs_normalize (bool, optional): Whether to normalize the observation. Defaults to True.
reward_normalize (bool, optional): Whether to normalize the reward. Defaults to True.
cost_normalize (bool, optional): Whether to normalize the cost. Defaults to True.
"""
assert self._eval_env, 'Your environment for evaluation does not exist!'
if self._env.need_time_limit_wrapper:
assert (
self._eval_env.max_episode_steps
), 'You must define max_episode_steps as an\
\ninteger or cancel the use of the time_limit wrapper.'
self._eval_env = TimeLimit(
self._eval_env,
time_limit=self._eval_env.max_episode_steps,
device=self._device,
)
if self._env.need_auto_reset_wrapper:
self._eval_env = AutoReset(self._eval_env, device=self._device)
if obs_normalize:
self._eval_env = ObsNormalize(self._eval_env, device=self._device)
self._eval_env = ActionScale(self._eval_env, low=-1.0, high=1.0, device=self._device)
self._eval_env = Unsqueeze(self._eval_env, device=self._device)

@property
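To make the new `need_evaluation` guard concrete: the adapter now builds the single-env evaluation copy only when the training environment asks for one. A minimal illustrative sketch (not from this commit; the import path and the attribute-style flags are assumptions about the CMDP interface, and a real task must implement the full CMDP API):

from omnisafe.envs.core import CMDP  # assumed location of the CMDP base class


class MyVectorizedTask(CMDP):  # illustrative only
    # Flags the adapter already consults when stacking wrappers.
    need_auto_reset_wrapper = True
    need_time_limit_wrapper = True
    # Setting this to False makes OnlineAdapter skip the evaluation copy and
    # `_wrapper_eval` entirely, which suits heavily vectorized simulators.
    need_evaluation = False
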
3 changes: 0 additions & 3 deletions omnisafe/adapter/onpolicy_adapter.py
@@ -109,9 +109,6 @@ def rollout( # pylint: disable=too-many-locals
                     last_value_c = torch.zeros(1)
                     if not done:
                         if epoch_end:
-                            logger.log(
-                                f'Warning: trajectory cut off when rollout by epoch at {self._ep_len[idx]} steps.',
-                            )
                             _, last_value_r, last_value_c, _ = agent.step(obs[idx])
                         if time_out:
                             _, last_value_r, last_value_c, _ = agent.step(
3 changes: 2 additions & 1 deletion omnisafe/algorithms/algo_wrapper.py
@@ -142,7 +142,8 @@ def _init_checks(self) -> None:
         assert self.cfgs.train_cfgs.parallel > 0, 'parallel must be greater than 0!'
         assert (
             self.env_id in support_envs()
-        ), f"{self.env_id} doesn't exist. Please choose from {support_envs()}."
+        ), f"{self.env_id} doesn't exist. Please choose from {support_envs()}.\
+            \nIf you are using Safe Isaac Gym environments, please install Isaac Gym first."
 
     def _init_algo(self) -> None:
         """Initialize the algorithm."""
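A small usage note (not part of the commit): the registry this assertion consults can be queried directly, which is handy for checking whether the Safe Isaac Gym tasks actually registered. The import path below is an assumption about where `support_envs` is exported:

from omnisafe.envs import support_envs  # assumed import path

env_id = 'ShadowHandOverSafeFinger'
if env_id not in support_envs():
    # Without Isaac Gym installed, the Safe Isaac Gym tasks never register,
    # which is exactly the case the extended error message points at.
    print(f'{env_id} is unavailable; install Isaac Gym to register it.')
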
156 changes: 156 additions & 0 deletions omnisafe/configs/on-policy/PPO.yaml
@@ -118,3 +118,159 @@ defaults:
      activation: tanh
      # learning rate
      lr: 0.0003

ShadowHandCatchOver2UnderarmSafeFinger:
  # training configurations
  train_cfgs:
    # number of vectorized environments
    vector_env_nums: 256
    # total number of steps to train
    total_steps: 100000000
  # algorithm configurations
  algo_cfgs:
    # number of steps to update the policy
    steps_per_epoch: 32768
    # number of iterations to update the policy
    update_iters: 8
    # batch size for each iteration
    batch_size: 8192
    # target kl divergence
    target_kl: 0.016
    # max gradient norm
    max_grad_norm: 1.0
    # use critic norm
    use_critic_norm: False
    # reward discount factor
    gamma: 0.96
    # normalize reward
    reward_normalize: False
    # normalize cost
    cost_normalize: False
    # normalize observation
    obs_normalize: False
  # model configurations
  model_cfgs:
    # actor network configurations
    actor:
      # hidden layer sizes
      hidden_sizes: [1024, 1024, 512]
    critic:
      # hidden layer sizes
      hidden_sizes: [1024, 1024, 512]
      # learning rate
      lr: 0.0006

ShadowHandOverSafeFinger:
  # training configurations
  train_cfgs:
    # number of vectorized environments
    vector_env_nums: 256
    # total number of steps to train
    total_steps: 100000000
  # algorithm configurations
  algo_cfgs:
    # number of steps to update the policy
    steps_per_epoch: 38400
    # number of iterations to update the policy
    update_iters: 8
    # batch size for each iteration
    batch_size: 8192
    # target kl divergence
    target_kl: 0.016
    # max gradient norm
    max_grad_norm: 1.0
    # use critic norm
    use_critic_norm: False
    # reward discount factor
    gamma: 0.96
    # normalize observation
    obs_normalize: False
  # model configurations
  model_cfgs:
    # actor network configurations
    actor:
      # hidden layer sizes
      hidden_sizes: [1024, 1024, 512]
    critic:
      # hidden layer sizes
      hidden_sizes: [1024, 1024, 512]
      # learning rate
      lr: 0.0006

ShadowHandCatchOver2UnderarmSafeJoint:
  # training configurations
  train_cfgs:
    # number of vectorized environments
    vector_env_nums: 256
    # total number of steps to train
    total_steps: 100000000
  # algorithm configurations
  algo_cfgs:
    # number of steps to update the policy
    steps_per_epoch: 32768
    # number of iterations to update the policy
    update_iters: 8
    # batch size for each iteration
    batch_size: 8192
    # target kl divergence
    target_kl: 0.016
    # max gradient norm
    max_grad_norm: 1.0
    # use critic norm
    use_critic_norm: False
    # reward discount factor
    gamma: 0.96
    # normalize reward
    reward_normalize: False
    # normalize cost
    cost_normalize: False
    # normalize observation
    obs_normalize: False
  # model configurations
  model_cfgs:
    # actor network configurations
    actor:
      # hidden layer sizes
      hidden_sizes: [1024, 1024, 512]
    critic:
      # hidden layer sizes
      hidden_sizes: [1024, 1024, 512]
      # learning rate
      lr: 0.0006

ShadowHandOverSafeJoint:
  # training configurations
  train_cfgs:
    # number of vectorized environments
    vector_env_nums: 256
    # total number of steps to train
    total_steps: 100000000
  # algorithm configurations
  algo_cfgs:
    # number of steps to update the policy
    steps_per_epoch: 38400
    # number of iterations to update the policy
    update_iters: 8
    # batch size for each iteration
    batch_size: 8192
    # target kl divergence
    target_kl: 0.016
    # max gradient norm
    max_grad_norm: 1.0
    # use critic norm
    use_critic_norm: False
    # reward discount factor
    gamma: 0.96
    # normalize observation
    obs_normalize: False
  # model configurations
  model_cfgs:
    # actor network configurations
    actor:
      # hidden layer sizes
      hidden_sizes: [1024, 1024, 512]
    critic:
      # hidden layer sizes
      hidden_sizes: [1024, 1024, 512]
      # learning rate
      lr: 0.0006
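
To show how these per-task sections are consumed, here is a minimal launch sketch (not part of the commit). It assumes Isaac Gym is installed and uses the standard `omnisafe.Agent` entry point with its `custom_cfgs` argument; the override value is illustrative:

import omnisafe

# The env id selects the matching section above, so the 256-env,
# steps_per_epoch=38400 settings apply without any extra flags.
custom_cfgs = {
    'train_cfgs': {
        'total_steps': 10000000,  # illustrative override of a key defined above
    },
}
agent = omnisafe.Agent('PPO', 'ShadowHandOverSafeFinger', custom_cfgs=custom_cfgs)
agent.learn()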
