Skip to content

Commit

Permalink
style: fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaiejj committed May 1, 2024
1 parent c7c8398 commit bb6e62a
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 14 deletions.
4 changes: 2 additions & 2 deletions omnisafe/adapter/onpolicy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def _log_value(
cost (torch.Tensor): The immediate step cost.
info (dict[str, Any]): Some information logged by the environment.
"""
self._ep_ret += info.get('original_reward', reward.mean()).cpu()
self._ep_cost += info.get('original_cost', cost.mean()).cpu()
self._ep_ret += info.get('original_reward', reward).cpu()
self._ep_cost += info.get('original_cost', cost).cpu()
self._ep_len += 1

def _log_metrics(self, logger: Logger, idx: int) -> None:
Expand Down
10 changes: 5 additions & 5 deletions omnisafe/common/control_barrier_function/crabs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ def training_step(self, batch):
self.log(f'model/{i}/training_loss', loss.item())

opt = self.optimizers()
opt.zero_grad()
opt.zero_grad() # type: ignore

self.manual_backward(total_loss)
self.manual_backward(total_loss) # type: ignore
nn.utils.clip_grad_norm_(self.parameters(), 10)
opt.step()
opt.step() # type: ignore

def validation_step(self, batch):
"""Validation step of the ensemble model.
Expand Down Expand Up @@ -291,10 +291,10 @@ def training_step(self, batch):
self.log(f'{self.name}/training_loss', loss.item(), on_step=False, on_epoch=True)

opt = self.optimizers()
opt.zero_grad()
opt.zero_grad() # type: ignore
self.manual_backward(loss, opt)
nn.utils.clip_grad_norm_(self.parameters(), 10)
opt.step()
opt.step() # type: ignore

return {
'loss': loss.item(),
Expand Down
3 changes: 3 additions & 0 deletions omnisafe/envs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class CMDP(ABC):
Attributes:
need_time_limit_wrapper (bool): Whether the environment need time limit wrapper.
need_auto_reset_wrapper (bool): Whether the environment need auto reset wrapper.
need_evaluation (bool): Whether to create an instance of environment for evaluation.
"""

_action_space: OmnisafeSpace
Expand Down Expand Up @@ -194,6 +195,8 @@ class Wrapper(CMDP):
Attributes:
_env (CMDP): The environment.
_device (torch.device): The device to use. Defaults to ``torch.device('cpu')``.
need_evaluation (bool): Whether to create an instance of environment for evaluation.
"""

def __init__(self, env: CMDP, device: torch.device = DEVICE_CPU) -> None:
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/envs/safety_isaac_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Environments of Safe Isaac Gym the Safety-Gymnasium."""
"""Environments of Safe Isaac Gym in the Safety-Gymnasium."""

from __future__ import annotations

Expand Down
20 changes: 14 additions & 6 deletions omnisafe/utils/isaac_gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,21 @@ def step(
dict[str, Any],
]:
"""Step the environment."""
obs, rews, costs, terminated, infos = super().step(action)
obs, rews, costs, terminated, infos = super().step(action.unsqueeze(0))
truncated = terminated
return obs, rews, costs, terminated, truncated, infos
return (
obs.squeeze(0),
rews.squeeze(0),
costs.squeeze(0),
terminated.squeeze(0),
truncated.squeeze(0),
infos,
)

def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment."""
obs = super().reset()
return obs, {}
return obs.squeeze(0), {}


def parse_sim_params(args: argparse.Namespace) -> gymapi.SimParams:
Expand All @@ -85,8 +92,8 @@ def parse_sim_params(args: argparse.Namespace) -> gymapi.SimParams:
sim_params.physx.num_subscenes = args.subscenes
sim_params.physx.max_gpu_contact_pairs = 8 * 1024 * 1024

sim_params.use_gpu_pipeline = args.use_gpu_pipeline
sim_params.physx.use_gpu = args.use_gpu
sim_params.use_gpu_pipeline = args.use_gpu_pipeline if args.device != 'cpu' else False
sim_params.physx.use_gpu = args.use_gpu if args.device != 'cpu' else False

if args.physics_engine == gymapi.SIM_PHYSX and args.num_threads > 0:
sim_params.physx.num_threads = args.num_threads
Expand Down Expand Up @@ -120,7 +127,7 @@ def make_isaac_gym_env(
{'name': '--torch-threads', 'type': int, 'default': 16},
]
args = gymutil.parse_arguments(custom_parameters=custom_parameters)
args.device = args.sim_device_type if args.use_gpu_pipeline else 'cpu'
args.device = args.sim_device_type if args.use_gpu_pipeline and args.device != 'cpu' else 'cpu'
sim_params = parse_sim_params(args=args)

device_id = int(str(device).rsplit(':', maxsplit=1)[-1]) if str(device) != 'cpu' else 0
Expand All @@ -136,6 +143,7 @@ def make_isaac_gym_env(
task_fn = ShadowHandOverSafeJoint
else:
raise NotImplementedError

task = task_fn(
num_envs=num_envs,
sim_params=sim_params,
Expand Down

0 comments on commit bb6e62a

Please sign in to comment.