refactor(wrapper): refactor the cuda setting (#176)
* refactor(wrapper): refactor the cuda setting

* chore: revert train_policy.py

* chore: set device in safety_gymnasium_env.py

* fix: [pre-commit.ci] auto fixes [...]

* fix(safety_gymnasium_env.py): fix device interface

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and zmsn-2077 committed Mar 26, 2023
1 parent 0e8114a commit 3bf7660
Showing 9 changed files with 102 additions and 95 deletions.
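The common thread across these files: the environment and every wrapper now receive a torch.device at construction and hand back tensors that already live on that device, so the adapters no longer call .to(device) after each step() or reset(). A minimal self-contained sketch of the pattern (the DeviceAwareEnv class below is hypothetical and only illustrates the idea; the real classes appear in the diffs that follow):

import torch

class DeviceAwareEnv:  # hypothetical stand-in for OmniSafe's env/wrapper classes
    def __init__(self, device: torch.device) -> None:
        self._device = device  # stored once at construction

    def reset(self):
        # every tensor the env hands out is created on the configured device
        obs = torch.zeros(3, dtype=torch.float32, device=self._device)
        return obs, {}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
obs, info = DeviceAwareEnv(device).reset()
print(obs.device)  # already on `device`; no .to() needed in the adapter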
16 changes: 6 additions & 10 deletions omnisafe/adapter/offpolicy_adapter.py
@@ -14,11 +14,9 @@
# ==============================================================================
"""OffPolicy Adapter for OmniSafe."""

-from functools import partial
from typing import Dict, Optional

import torch
-from gymnasium import spaces

from omnisafe.adapter.online_adapter import OnlineAdapter
from omnisafe.common.buffer import VectorOffPolicyBuffer
@@ -60,16 +58,14 @@ def roll_out( # pylint: disable=too-many-locals
logger (Logger): Logger.
use_rand_action (bool): Whether to use random action.
"""
-        if use_rand_action:
-            if isinstance(self._env.action_space, spaces.Box):
-                act_fn = partial(
-                    torch.rand, size=(self._env.num_envs, *self._env.action_space.shape)
-                )
-            else:
-                act_fn = partial(agent.step, self._current_obs, deterministic=False)
-        for _ in range(roll_out_step):
-            act = act_fn()
+        for _ in range(roll_out_step):
+            if use_rand_action:
+                act = torch.as_tensor(self._env.sample_action(), dtype=torch.float32).to(
+                    self._device
+                )
+            else:
+                act = agent.step(self._current_obs, deterministic=False)

next_obs, reward, cost, terminated, truncated, info = self.step(act)

self._log_value(reward=reward, cost=cost, info=info)
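A self-contained sketch of the new action-selection flow in roll_out(); the DummyEnv and DummyAgent below are stand-ins, and only the control flow and device handling mirror the diff above:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class DummyEnv:  # placeholder for the wrapped environment
    def sample_action(self):
        return torch.rand(1, 2)  # env-side random action

    def step(self, act):
        return torch.zeros(1, 3, device=act.device)  # pretend next observation

class DummyAgent:  # placeholder for the actor-critic agent
    def step(self, obs, deterministic=False):
        return torch.zeros(1, 2, device=obs.device)

env, agent = DummyEnv(), DummyAgent()
current_obs = torch.zeros(1, 3, device=device)
use_rand_action, roll_out_step = True, 4

for _ in range(roll_out_step):
    if use_rand_action:
        # random actions now come from the env and are moved to the training device
        act = torch.as_tensor(env.sample_action(), dtype=torch.float32).to(device)
    else:
        act = agent.step(current_obs, deterministic=False)
    current_obs = env.step(act)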
33 changes: 12 additions & 21 deletions omnisafe/adapter/online_adapter.py
@@ -45,36 +45,35 @@ def __init__( # pylint: disable=too-many-arguments
assert env_id in support_envs(), f'Env {env_id} is not supported.'

self._env_id = env_id
-        self._env = make(env_id, num_envs=num_envs)
+        self._env = make(env_id, num_envs=num_envs, device=cfgs.train_cfgs.device)
+        self._cfgs = cfgs
+        self._device = cfgs.train_cfgs.device
self._wrapper(
obs_normalize=cfgs.algo_cfgs.obs_normalize,
reward_normalize=cfgs.algo_cfgs.reward_normalize,
cost_normalize=cfgs.algo_cfgs.cost_normalize,
)
self._env.set_seed(seed)

-        self._cfgs = cfgs
-        self._device = cfgs.train_cfgs.device

def _wrapper(
self,
obs_normalize: bool = True,
reward_normalize: bool = True,
cost_normalize: bool = True,
):
if self._env.need_time_limit_wrapper:
-            self._env = TimeLimit(self._env, time_limit=1000)
+            self._env = TimeLimit(self._env, device=self._device, time_limit=1000)
        if self._env.need_auto_reset_wrapper:
-            self._env = AutoReset(self._env)
+            self._env = AutoReset(self._env, device=self._device)
        if obs_normalize:
-            self._env = ObsNormalize(self._env)
+            self._env = ObsNormalize(self._env, device=self._device)
        if reward_normalize:
-            self._env = RewardNormalize(self._env)
+            self._env = RewardNormalize(self._env, device=self._device)
        if cost_normalize:
-            self._env = CostNormalize(self._env)
-        self._env = ActionScale(self._env, low=-1.0, high=1.0)
+            self._env = CostNormalize(self._env, device=self._device)
+        self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0)
        if self._env.num_envs == 1:
-            self._env = Unsqueeze(self._env)
+            self._env = Unsqueeze(self._env, device=self._device)

@property
def action_space(self) -> OmnisafeSpace:
@@ -111,14 +110,7 @@ def step(
truncated (torch.Tensor): whether the episode has been truncated due to a time limit.
info (Dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning).
"""
-        obs, reward, cost, terminated, truncated, info = self._env.step(action)
-        obs, reward, cost, terminated, truncated = map(
-            lambda x: x.to(self._device),
-            (obs, reward, cost, terminated, truncated),
-        )
-        if info.get('final_observation') is not None:
-            info['final_observation'] = info['final_observation'].to(self._device)
-        return obs, reward, cost, terminated, truncated, info
+        return self._env.step(action)

def reset(self) -> Tuple[torch.Tensor, Dict]:
"""Resets the environment and returns an initial observation.
@@ -130,8 +122,7 @@ def reset(self) -> Tuple[torch.Tensor, Dict]:
observation (torch.Tensor): the initial observation of the space.
info (Dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning).
"""
-        obs, info = self._env.reset()
-        return obs.to(self._device), info
+        return self._env.reset()

def save(self) -> Dict[str, torch.nn.Module]:
"""Save the environment.
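For contrast, the per-field device transfer that OnlineAdapter.step() used to perform (and that this hunk removes) looks like this as a stand-alone snippet; after the change the wrapped environment already returns device-resident tensors, so step() and reset() simply forward to self._env:

import torch

device = torch.device('cpu')
obs, reward, cost, terminated, truncated = (torch.zeros(1) for _ in range(5))
info = {'final_observation': torch.zeros(1)}

# old adapter-level transfer, now removed:
obs, reward, cost, terminated, truncated = map(
    lambda x: x.to(device), (obs, reward, cost, terminated, truncated)
)
if info.get('final_observation') is not None:
    info['final_observation'] = info['final_observation'].to(device)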
10 changes: 5 additions & 5 deletions omnisafe/adapter/saute_adapter.py
@@ -66,16 +66,16 @@ def _wrapper(
cost_normalize: bool = False,
):
if self._env.need_time_limit_wrapper:
-            self._env = TimeLimit(self._env, time_limit=1000)
+            self._env = TimeLimit(self._env, device=self._device, time_limit=1000)
        if self._env.need_auto_reset_wrapper:
-            self._env = AutoReset(self._env)
+            self._env = AutoReset(self._env, device=self._device)
        if obs_normalize:
-            self._env = ObsNormalize(self._env)
+            self._env = ObsNormalize(self._env, device=self._device)
        assert reward_normalize is False, 'Reward normalization is not supported'
        assert cost_normalize is False, 'Cost normalization is not supported'
-        self._env = ActionScale(self._env, low=-1.0, high=1.0)
+        self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0)
        if self._env.num_envs == 1:
-            self._env = Unsqueeze(self._env)
+            self._env = Unsqueeze(self._env, device=self._device)

def reset(self) -> Tuple[torch.Tensor, Dict]:
obs, info = self._env.reset()
41 changes: 23 additions & 18 deletions omnisafe/algorithms/on_policy/second_order/cpo.py
@@ -16,7 +16,6 @@

from typing import Tuple

-import numpy as np
import torch

from omnisafe.algorithms import registry
@@ -116,8 +115,12 @@ def _cpo_search_step(
acceptance_step = step + 1

with torch.no_grad():
-                # loss of policy reward from target/expected reward
-                loss_reward, _ = self._loss_pi(obs=obs, act=act, logp=logp, adv=adv_r)
+                try:
+                    # loss of policy reward from target/expected reward
+                    loss_reward, _ = self._loss_pi(obs=obs, act=act, logp=logp, adv=adv_r)
+                except ValueError:
+                    step_frac *= decay
+                    continue
# loss of cost of policy cost from real/expected reward
loss_cost = self._loss_pi_cost(obs=obs, act=act, logp=logp, adv_c=adv_c)
# compute KL distance between new and old policy
@@ -139,7 +142,10 @@
# check whether there are nan.
if not torch.isfinite(loss_reward) and not torch.isfinite(loss_cost):
self._logger.log('WARNING: loss_pi not finite')
-            elif loss_reward_improve < 0 if optim_case > 1 else False:
+            if not torch.isfinite(kl):
+                self._logger.log('WARNING: KL not finite')
+                continue
+            if loss_reward_improve < 0 if optim_case > 1 else False:
self._logger.log('INFO: did not improve improve <0')
# change of cost's range
elif loss_cost_diff > max(-violation_c, 0):
@@ -236,14 +242,13 @@ def _update_actor(

b_grad = get_flat_gradients_from(self._actor_critic.actor)
ep_costs = self._logger.get_stats('Metrics/EpCost')[0] - self._cfgs.algo_cfgs.cost_limit
-        cost = ep_costs / (self._logger.get_stats('Metrics/EpLen')[0] + 1e-8)

p = conjugate_gradients(self._fvp, b_grad, self._cfgs.algo_cfgs.cg_iters)
q = xHx
r = grad.dot(p)
s = b_grad.dot(p)

-        if b_grad.dot(b_grad) <= 1e-6 and cost < 0:
+        if b_grad.dot(b_grad) <= 1e-6 and ep_costs < 0:
# feasible step and cost grad is zero: use plain TRPO update...
A = torch.zeros(1)
B = torch.zeros(1)
@@ -253,17 +258,17 @@
assert torch.isfinite(s).all(), 's is not finite'

A = q - r**2 / (s + 1e-8)
-            B = 2 * self._cfgs.algo_cfgs.target_kl - cost**2 / (s + 1e-8)
+            B = 2 * self._cfgs.algo_cfgs.target_kl - ep_costs**2 / (s + 1e-8)

-            if cost < 0 and B < 0:
+            if ep_costs < 0 and B < 0:
                # point in trust region is feasible and safety boundary doesn't intersect
                # ==> entire trust region is feasible
                optim_case = 3
-            elif cost < 0 <= B:
+            elif ep_costs < 0 <= B:
                # point in trust region is feasible but safety boundary intersects
                # ==> only part of trust region is feasible
                optim_case = 2
-            elif cost >= 0 and B >= 0:
+            elif ep_costs >= 0 and B >= 0:
# point in trust region is infeasible and cost boundary doesn't intersect
# ==> entire trust region is infeasible
optim_case = 1
@@ -296,16 +301,16 @@ def project(data: torch.Tensor, low: float, high: float) -> torch.Tensor:
# where projection(str,b,c)=max(b,min(str,c))
# may be regarded as a projection from effective region towards safety region
r_num = r.item()
-            eps_cost = cost + 1e-8
-            if cost < 0:
+            eps_cost = ep_costs + 1e-8
+            if ep_costs < 0:
                lambda_a_star = project(lambda_a, 0.0, r_num / eps_cost)
-                lambda_b_star = project(lambda_b, r_num / eps_cost, np.inf)
+                lambda_b_star = project(lambda_b, r_num / eps_cost, torch.inf)
            else:
-                lambda_a_star = project(lambda_a, r_num / eps_cost, np.inf)
+                lambda_a_star = project(lambda_a, r_num / eps_cost, torch.inf)
                lambda_b_star = project(lambda_b, 0.0, r_num / eps_cost)

            def f_a(lam):
-                return -0.5 * (A / (lam + 1e-8) + B * lam) - r * cost / (s + 1e-8)
+                return -0.5 * (A / (lam + 1e-8) + B * lam) - r * ep_costs / (s + 1e-8)

def f_b(lam):
return -0.5 * (q / (lam + 1e-8) + 2 * self._cfgs.algo_cfgs.target_kl * lam)
@@ -316,15 +321,15 @@ def f_b(lam):

# discard all negative values with torch.clamp(x, min=0)
# Nu_star = (lambda_star * - r)/s
-            nu_star = torch.clamp(lambda_star * cost - r, min=0) / (s + 1e-8)
+            nu_star = torch.clamp(lambda_star * ep_costs - r, min=0) / (s + 1e-8)
# final x_star as final direction played as policy's loss to backward and update
step_direction = 1.0 / (lambda_star + 1e-8) * (x - nu_star * p)

else: # case == 0
# purely decrease costs
# without further check
lambda_star = torch.zeros(1)
-            nu_star = np.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (s + 1e-8))
+            nu_star = torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (s + 1e-8))
step_direction = -nu_star * p

step_direction, accept_step = self._cpo_search_step(
Expand All @@ -339,7 +344,7 @@ def f_b(lam):
loss_reward_before=loss_reward_before,
loss_cost_before=loss_cost_before,
total_steps=20,
-            violation_c=cost,
+            violation_c=ep_costs,
optim_case=optim_case,
)

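The dual-variable computation in _update_actor now uses ep_costs (episode cost minus the cost limit, no longer divided by episode length) and stays entirely in torch, with torch.inf and torch.sqrt replacing their numpy counterparts. A small worked sketch of the projection step, with made-up scalars standing in for r, ep_costs, lambda_a, and lambda_b:

import torch

def project(data: torch.Tensor, low: float, high: float) -> torch.Tensor:
    # max(low, min(data, high)), as described in the comment above project()
    return torch.clamp(data, low, high)

r_num = 0.5                    # hypothetical r.item()
ep_costs = 2.0                 # hypothetical EpCost - cost_limit (infeasible case)
eps_cost = ep_costs + 1e-8
lambda_a, lambda_b = torch.tensor(1.0), torch.tensor(3.0)

# infeasible branch (ep_costs >= 0); note torch.inf replacing np.inf
lambda_a_star = project(lambda_a, r_num / eps_cost, torch.inf)
lambda_b_star = project(lambda_b, 0.0, r_num / eps_cost)
print(lambda_a_star, lambda_b_star)  # tensor(1.) tensor(0.2500)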
9 changes: 4 additions & 5 deletions omnisafe/algorithms/on_policy/second_order/pcpo.py
@@ -91,9 +91,8 @@ def _update_actor(

b_grad = get_flat_gradients_from(self._actor_critic.actor)
ep_costs = self._logger.get_stats('Metrics/EpCost')[0] - self._cfgs.algo_cfgs.cost_limit
-        cost = ep_costs / (self._logger.get_stats('Metrics/EpLen')[0] + 1e-8)

-        self._logger.log(f'c = {cost}')
+        self._logger.log(f'c = {ep_costs}')
self._logger.log(f'b^T b = {b_grad.dot(b_grad).item()}')

p = conjugate_gradients(self._fvp, b_grad, self._cfgs.algo_cfgs.cg_iters)
@@ -104,7 +103,7 @@
step_direction = (
torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (q + 1e-8)) * H_inv_g
- torch.clamp_min(
-                (torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / q) * r + cost) / s,
+                (torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / q) * r + ep_costs) / s,
torch.tensor(0.0, device=self._device),
)
* p
@@ -121,8 +120,8 @@
adv_c=adv_c,
loss_reward_before=loss_reward_before,
loss_cost_before=loss_cost_before,
-            total_steps=20,
-            violation_c=cost,
+            total_steps=200,
+            violation_c=ep_costs,
)
theta_new = theta_old + step_direction
set_param_values_to_model(self._actor_critic.actor, theta_new)
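A sketch of the PCPO step direction from the hunk above, with made-up scalars standing in for target_kl, q, r, s, and ep_costs and unit vectors for H_inv_g and p; the point is that the clamp term now uses ep_costs directly:

import torch

target_kl, q, r, s, ep_costs = 0.01, 4.0, 0.3, 2.0, 1.5   # hypothetical values
H_inv_g, p = torch.ones(3), torch.ones(3)

step_direction = (
    torch.sqrt(torch.tensor(2 * target_kl / (q + 1e-8))) * H_inv_g
    - torch.clamp_min(
        (torch.sqrt(torch.tensor(2 * target_kl / q)) * r + ep_costs) / s,
        torch.tensor(0.0),
    )
    * p
)
print(step_direction.shape)  # torch.Size([3])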
3 changes: 2 additions & 1 deletion omnisafe/envs/core.py
@@ -193,13 +193,14 @@ class Wrapper(CMDP):
"""

-    def __init__(self, env: CMDP) -> None:
+    def __init__(self, env: CMDP, device: torch.device) -> None:
"""Initialize the wrapper.
Args:
env (CMDP): the environment.
"""
self._env = env
+        self._device = device

def __getattr__(self, name: str) -> Any:
"""Get the attribute of the environment.
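With the device stored on the base class, a concrete wrapper can create or move tensors without threading a device argument through every call. A minimal sketch (the Wrapper below is a simplified stand-in for omnisafe.envs.core.Wrapper, and CastObs is a hypothetical example subclass):

import torch
from typing import Any, Dict, Tuple

class Wrapper:  # simplified stand-in for omnisafe.envs.core.Wrapper
    def __init__(self, env: Any, device: torch.device) -> None:
        self._env = env
        self._device = device

class CastObs(Wrapper):  # hypothetical example wrapper
    def reset(self) -> Tuple[torch.Tensor, Dict]:
        obs, info = self._env.reset()
        # the stored device is used wherever tensors are produced
        return torch.as_tensor(obs, dtype=torch.float32, device=self._device), info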
19 changes: 13 additions & 6 deletions omnisafe/envs/safety_gymnasium_env.py
@@ -75,7 +75,9 @@ class SafetyGymnasiumEnv(CMDP):
need_auto_reset_wrapper = False
need_time_limit_wrapper = False

-    def __init__(self, env_id: str, num_envs: int = 1, **kwargs) -> None:
+    def __init__(
+        self, env_id: str, num_envs: int = 1, device: torch.device = torch.device('cpu'), **kwargs
+    ) -> None:
super().__init__(env_id)
if num_envs > 1:
self._env = safety_gymnasium.vector.make(env_id=env_id, num_envs=num_envs, **kwargs)
@@ -88,13 +90,16 @@ def __init__(self, env_id: str, num_envs: int = 1, **kwargs) -> None:

self._num_envs = num_envs
self._metadata = self._env.metadata
+        self._device = device

def step(
self, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
-        obs, reward, cost, terminated, truncated, info = self._env.step(action)
+        obs, reward, cost, terminated, truncated, info = self._env.step(
+            action.detach().cpu().numpy()
+        )
obs, reward, cost, terminated, truncated = map(
-            lambda x: torch.as_tensor(x, dtype=torch.float32),
+            lambda x: torch.as_tensor(x, dtype=torch.float32, device=self._device),
(obs, reward, cost, terminated, truncated),
)
if 'final_observation' in info:
Expand All @@ -105,20 +110,22 @@ def step(
]
)
info['final_observation'] = torch.as_tensor(
-                info['final_observation'], dtype=torch.float32
+                info['final_observation'], dtype=torch.float32, device=self._device
)

return obs, reward, cost, terminated, truncated, info

def reset(self, seed: Optional[int] = None) -> Tuple[torch.Tensor, Dict]:
obs, info = self._env.reset(seed=seed)
-        return torch.as_tensor(obs, dtype=torch.float32), info
+        return torch.as_tensor(obs, dtype=torch.float32, device=self._device), info

def set_seed(self, seed: int) -> None:
self.reset(seed=seed)

def sample_action(self) -> torch.Tensor:
-        return torch.as_tensor(self._env.action_space.sample(), dtype=torch.float32)
+        return torch.as_tensor(
+            self._env.action_space.sample(), dtype=torch.float32, device=self._device
+        )

def render(self) -> Any:
return self._env.render()
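The environment is now the single place where the torch/numpy and device boundary is crossed: actions are detached and converted to CPU numpy before being handed to safety-gymnasium, and everything coming back is converted with torch.as_tensor(..., device=self._device). A stripped-down sketch of that round trip, using a hypothetical numpy-based env in place of safety_gymnasium:

import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class FakeGymEnv:  # hypothetical numpy-based env standing in for safety_gymnasium
    def step(self, action: np.ndarray):
        obs = np.zeros(3, dtype=np.float64)
        return obs, 1.0, 0.0, False, False, {}

env = FakeGymEnv()
act = torch.rand(2, device=device, requires_grad=True)

# torch -> numpy on the way in, numpy -> torch (on `device`) on the way out
obs, reward, cost, terminated, truncated, info = env.step(act.detach().cpu().numpy())
obs, reward, cost, terminated, truncated = map(
    lambda x: torch.as_tensor(x, dtype=torch.float32, device=device),
    (obs, reward, cost, terminated, truncated),
)
print(obs.device, obs.dtype)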