Skip to content

Commit

Permalink
Merge 538b14a into 5085404
Browse files: browse the repository at this point in the history
  • Loading branch information
Gaiejj authored May 25, 2023
2 parents 5085404 + 538b14a commit ed410c0
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 20 deletions.
8 changes: 4 additions & 4 deletions omnisafe/adapter/early_terminated_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, env_id: str, num_envs: int, seed: int, cfgs: Config) -> None:
super().__init__(env_id, num_envs, seed, cfgs)

self._cost_limit: float = cfgs.algo_cfgs.cost_limit
self._cost_logger: torch.Tensor = torch.zeros(self._env.num_envs)
self._cost_logger: torch.Tensor = torch.zeros(self._env.num_envs).to(self._device)

def step(
self,
Expand Down Expand Up @@ -79,9 +79,9 @@ def step(
self._cost_logger += info.get('original_cost', cost)

if self._cost_logger > self._cost_limit:
reward = torch.zeros(self._env.num_envs)
terminated = torch.ones(self._env.num_envs)
reward = torch.zeros(self._env.num_envs).to(self._device)
terminated = torch.ones(self._env.num_envs).to(self._device)
next_obs, _ = self._env.reset()
self._cost_logger = torch.zeros(self._env.num_envs)
self._cost_logger = torch.zeros(self._env.num_envs).to(self._device)

return next_obs, reward, cost, terminated, truncated, info
6 changes: 3 additions & 3 deletions omnisafe/adapter/saute_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, env_id: str, num_envs: int, seed: int, cfgs: Config) -> None:
/ (1 - self._cfgs.algo_cfgs.saute_gamma)
/ self._cfgs.algo_cfgs.max_ep_len
* torch.ones(num_envs, 1)
)
).to(self._device)

assert isinstance(self._env.observation_space, Box), 'Observation space must be Box'
self._observation_space: Box = Box(
Expand Down Expand Up @@ -120,7 +120,7 @@ def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
info: Some information logged by the environment.
"""
obs, info = self._env.reset()
self._safety_obs = torch.ones(self._env.num_envs, 1)
self._safety_obs = torch.ones(self._env.num_envs, 1).to(self._device)
obs = self._augment_obs(obs)
return obs, info

Expand Down Expand Up @@ -236,7 +236,7 @@ def _reset_log(self, idx: int | None = None) -> None:
"""
super()._reset_log(idx)
if idx is None:
self._ep_budget = torch.zeros(self._env.num_envs)
self._ep_budget = torch.zeros(self._env.num_envs).to(self._device)
else:
self._ep_budget[idx] = 0

Expand Down
16 changes: 9 additions & 7 deletions omnisafe/adapter/simmer_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,17 @@ def __init__(self, env_id: str, num_envs: int, seed: int, cfgs: Config) -> None:
/ (1 - self._cfgs.algo_cfgs.saute_gamma)
/ self._cfgs.algo_cfgs.max_ep_len
* torch.ones(num_envs, 1)
)
).to(self._device)
self._upper_budget: torch.Tensor = (
self._cfgs.algo_cfgs.upper_budget
* (1 - self._cfgs.algo_cfgs.saute_gamma**self._cfgs.algo_cfgs.max_ep_len)
/ (1 - self._cfgs.algo_cfgs.saute_gamma)
/ self._cfgs.algo_cfgs.max_ep_len
* torch.ones(num_envs, 1)
).to(self._device)
self._rel_safety_budget: torch.Tensor = (self._safety_budget / self._upper_budget).to(
self._device,
)
self._rel_safety_budget: torch.Tensor = self._safety_budget / self._upper_budget

assert isinstance(self._env.observation_space, Box), 'Observation space must be Box'
self._observation_space: Box = Box(
Expand All @@ -83,7 +85,7 @@ def __init__(self, env_id: str, num_envs: int, seed: int, cfgs: Config) -> None:
)
self._controller: BaseSimmerAgent = SimmerPIDAgent(
cfgs=cfgs.control_cfgs,
budget_bound=self._upper_budget,
budget_bound=self._upper_budget.cpu(),
)

def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
Expand All @@ -98,7 +100,7 @@ def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
info: Some information logged by the environment.
"""
obs, info = self._env.reset()
self._safety_obs = self._rel_safety_budget * torch.ones(self._num_envs, 1)
self._safety_obs = self._rel_safety_budget * torch.ones(self._num_envs, 1).to(self._device)
obs = self._augment_obs(obs)
return obs, info

Expand All @@ -115,6 +117,6 @@ def control_budget(self, ep_costs: torch.Tensor) -> None:
/ self._cfgs.algo_cfgs.max_ep_len
)
self._safety_budget = self._controller.act(
safety_budget=self._safety_budget,
observation=ep_costs,
)
safety_budget=self._safety_budget.cpu(),
observation=ep_costs.cpu(),
).to(self._device)
2 changes: 2 additions & 0 deletions omnisafe/common/experiment_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ def run(
hashed_exp_name = var['env_id'][:30] + '---' + hash_string(exp_name)
exp_names.append(':'.join((hashed_exp_name[:5], exp_name)))
exp_log_dir = os.path.join(self.log_dir, hashed_exp_name, '')
if not var.get('logger_cfgs'):
var['logger_cfgs'] = {'log_dir': './exp'}
var['logger_cfgs'].update({'log_dir': exp_log_dir})
self.save_same_exps_config(exp_log_dir, var)
results.append(pool.submit(thunk, str(idx), var['algo'], var['env_id'], var))
Expand Down
3 changes: 2 additions & 1 deletion omnisafe/common/statistics_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def make_config_groups(
# value of parameter is determined above
group_config.pop(parameter)
# seed is not a parameter
group_config.pop('seed')
if 'seed' in group_config:
group_config.pop('seed')
if 'train_cfgs' in group_config:
group_config['train_cfgs'].pop('device', None)
# combine all possible combinations of other parameters
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/models/actor/perturbation_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__( # pylint: disable=too-many-arguments

self.vae = VAE(obs_space, act_space, hidden_sizes, activation, weight_initialization_mode)
self.perturbation = build_mlp_network(
sizes=[self._obs_dim + self._act_dim, *hidden_sizes] + [self._act_dim],
sizes=[self._obs_dim + self._act_dim, *hidden_sizes, self._act_dim],
activation=activation,
output_activation='tanh',
weight_initialization_mode=weight_initialization_mode,
Expand Down
4 changes: 2 additions & 2 deletions omnisafe/models/actor/vae_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def __init__( # pylint: disable=too-many-arguments
self._latent_dim = self._act_dim * 2

self._encoder = build_mlp_network(
sizes=[self._obs_dim + self._act_dim, *hidden_sizes] + [self._latent_dim * 2],
sizes=[self._obs_dim + self._act_dim, *hidden_sizes, self._latent_dim * 2],
activation=activation,
weight_initialization_mode=weight_initialization_mode,
)
self._decoder = build_mlp_network(
sizes=[self._obs_dim + self._latent_dim, *hidden_sizes] + [self._act_dim],
sizes=[self._obs_dim + self._latent_dim, *hidden_sizes, self._act_dim],
activation=activation,
weight_initialization_mode=weight_initialization_mode,
)
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/models/critic/q_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
critic = nn.Sequential(obs_encoder, net)
else:
net = build_mlp_network(
[self._obs_dim + self._act_dim, *hidden_sizes] + [1],
[self._obs_dim + self._act_dim, *hidden_sizes, 1],
activation=activation,
weight_initialization_mode=weight_initialization_mode,
)
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/models/offline/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
self.activation = activation
self.hidden_sizes = hidden_sizes
self.net = build_mlp_network(
[self._obs_dim, *list(hidden_sizes)] + [self._out_dim],
[self._obs_dim, *list(hidden_sizes), self._out_dim],
activation=activation,
weight_initialization_mode=weight_initialization_mode,
)
Expand Down

0 comments on commit ed410c0

Please sign in to comment.