feat: support policy evaluation (PKU-Alignment#137)
Co-authored-by: ruiyang sun <[email protected]>
2 people authored and zmsn-2077 committed Mar 14, 2023
1 parent e81827b commit de04bb2
Showing 6 changed files with 300 additions and 21 deletions.
14 changes: 9 additions & 5 deletions examples/evaluate_saved_policy.py
@@ -20,12 +20,16 @@


# Just fill your experiment's log directory in here.
- # Such as: ~/omnisafe/runs/SafetyPointGoal1-v0/CPO/seed-000-2022-12-25_14-45-05
# Such as: ~/omnisafe/examples/runs/PPOLag-<SafetyPointGoal1-v0>/seed-000-2023-03-07-20-25-48
LOG_DIR = ''

play = True
save_replay = True
if __name__ == '__main__':
- evaluator = omnisafe.Evaluator()
evaluator = omnisafe.Evaluator(play=play, save_replay=save_replay)
for item in os.scandir(os.path.join(LOG_DIR, 'torch_save')):
if item.is_file() and item.name.split('.')[-1] == 'pt':
- evaluator.load_saved_model(save_dir=LOG_DIR, model_name=item.name)
- evaluator.render(num_episode=10, camera_name='track', width=256, height=256)
evaluator.load_saved(
save_dir=LOG_DIR, model_name=item.name, camera_name='track', width=256, height=256
)
evaluator.render(num_episodes=1)
evaluator.evaluate(num_episodes=1)
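For reference, the full example script after this patch reads roughly as below. This is a sketch: the two import lines and the empty LOG_DIR placeholder sit above the hunk shown here and are assumed rather than part of the diff.

import os

import omnisafe

# Fill in your experiment's log directory, e.g.
# ~/omnisafe/examples/runs/PPOLag-<SafetyPointGoal1-v0>/seed-000-2023-03-07-20-25-48
LOG_DIR = ''

play = True
save_replay = True
if __name__ == '__main__':
    evaluator = omnisafe.Evaluator(play=play, save_replay=save_replay)
    # Evaluate every checkpoint saved under <LOG_DIR>/torch_save.
    for item in os.scandir(os.path.join(LOG_DIR, 'torch_save')):
        if item.is_file() and item.name.split('.')[-1] == 'pt':
            evaluator.load_saved(
                save_dir=LOG_DIR, model_name=item.name, camera_name='track', width=256, height=256
            )
            evaluator.render(num_episodes=1)
            evaluator.evaluate(num_episodes=1)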
1 change: 1 addition & 0 deletions omnisafe/__init__.py
@@ -17,6 +17,7 @@
from omnisafe import algorithms
from omnisafe.algorithms import ALGORITHMS
from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent
from omnisafe.evaluator import Evaluator

# from omnisafe.algorithms.env_wrapper import EnvWrapper as Env
from omnisafe.version import __version__
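With this re-export, the evaluator can be constructed directly from the top-level package, as the example above does. A minimal check (the argument values are illustrative):

import omnisafe

# Evaluator now sits at the package root alongside Agent and ALGORITHMS.
evaluator = omnisafe.Evaluator(play=False, save_replay=True)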
8 changes: 6 additions & 2 deletions omnisafe/common/normalizer.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Vector Buffer."""
"""Implementation of Normalizer."""

- from typing import Tuple
from typing import Any, Mapping, Tuple

import torch
import torch.nn as nn
@@ -108,3 +108,7 @@ def _push(self, raw_data: torch.Tensor) -> None:
self._var = self._sumsq / (self._count - 1)
self._std = torch.sqrt(self._var)
self._std = torch.max(self._std, 1e-2 * torch.ones_like(self._std))

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self._first = False
return super().load_state_dict(state_dict, strict)
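The `load_state_dict` override clears `_first` so that a restored normalizer keeps using the saved running statistics instead of re-initializing them on the next push. A minimal sketch of how the evaluator exercises this, with the checkpoint path and observation shape as illustrative assumptions:

import torch

from omnisafe.common import Normalizer

# Hypothetical checkpoint file; the evaluator loads it from '<save_dir>/torch_save/<model_name>'.
model_params = torch.load('torch_save/epoch-500.pt')

# Rebuild the normalizer with the saved observation shape, then restore the running statistics.
obs_normalizer = Normalizer(shape=(60,), clip=5)  # shape is illustrative
obs_normalizer.load_state_dict(model_params['obs_normalizer'])
# _first is now False, so the restored mean/std are used as-is for normalization.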
10 changes: 10 additions & 0 deletions omnisafe/envs/core.py
@@ -42,6 +42,7 @@ class CMDP(ABC):
_support_envs: List[str]
_action_space: OmnisafeSpace
_observation_space: OmnisafeSpace
_metadata: Dict[str, Any]

_num_envs: int
_time_limit: Optional[int] = None
@@ -86,6 +87,15 @@ def observation_space(self) -> OmnisafeSpace:
"""
return self._observation_space

@property
def metadata(self) -> Dict[str, Any]:
"""The metadata of the environment.
Returns:
Dict[str, Any]: the metadata.
"""
return self._metadata

@property
def num_envs(self) -> int:
"""The parallel environments.
1 change: 1 addition & 0 deletions omnisafe/envs/safety_gymnasium_env.py
@@ -86,6 +86,7 @@ def __init__(self, env_id: str, num_envs: int = 1, **kwargs) -> None:
self._observation_space = self._env.observation_space

self._num_envs = num_envs
self._metadata = self._env.metadata

def step(
self, action: torch.Tensor
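The new `metadata` attribute, forwarded from the underlying Gymnasium environment, is what lets the evaluator query the render fps (see the `fps` property in the evaluator below). A small sketch, assuming a registered Safety-Gymnasium task id:

from omnisafe.envs.core import make

# 'SafetyPointGoal1-v0' is an illustrative task id; render_mode matches what the evaluator sets.
env = make(env_id='SafetyPointGoal1-v0', num_envs=1, render_mode='rgb_array')

# 'render_fps' may be absent from metadata, in which case the evaluator falls back to 30.
fps = env.metadata.get('render_fps', 30)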
287 changes: 273 additions & 14 deletions omnisafe/evaluator.py
@@ -14,36 +14,295 @@
# ==============================================================================
"""Implementation of Evaluator."""

import json
import os
import warnings
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from gymnasium.spaces import Box
from gymnasium.utils.save_video import save_video

from omnisafe.common import Normalizer
from omnisafe.envs.core import CMDP, make
from omnisafe.envs.wrapper import ActionScale, ObsNormalize, TimeLimit
from omnisafe.models.actor import ActorBuilder
from omnisafe.models.base import Actor
from omnisafe.utils.config import Config


class Evaluator: # pylint: disable=too-many-instance-attributes
"""This class includes common evaluation methods for safe RL algorithms."""

- def __init__(self) -> None:
- pass
# pylint: disable-next=too-many-arguments
def __init__(
self,
play: bool = True,
save_replay: bool = True,
):
"""Initialize the evaluator.
Args:
env (gymnasium.Env): the environment. if None, the environment will be created from the config.
pi (omnisafe.algos.models.actor.Actor): the policy. if None, the policy will be created from the config.
obs_normalize (omnisafe.algos.models.obs_normalize): the observation Normalize.
"""
# set the attributes
self._env: CMDP
self._actor: Actor

# used when load model from saved file.
self._cfgs: Config
self._save_dir: str
self._model_name: str

# set the render mode
self._play = play
self._save_replay = save_replay

self._dividing_line = '\n' + '#' * 50 + '\n'

- def load_saved_model(self, save_dir: str, model_name: str) -> None:
- """Load saved model from save_dir.
self.__set_render_mode(play, save_replay)

def __set_render_mode(self, play: bool = True, save_replay: bool = True):
"""Set the render mode.
Args:
- save_dir (str): The directory of saved model.
- model_name (str): The name of saved model.
play (bool): whether to play the video.
save_replay (bool): whether to save the video.
"""
# set the render mode
if play and save_replay:
self._render_mode = 'rgb_array'
elif play and not save_replay:
self._render_mode = 'human'
elif not play and save_replay:
self._render_mode = 'rgb_array_list'
else:
raise NotImplementedError('The render mode is not implemented.')

def __load_cfgs(self, save_dir: str):
"""Load the config from the save directory.
Args:
save_dir (str): directory where the model is saved.
"""
cfg_path = os.path.join(save_dir, 'config.json')
try:
with open(cfg_path, encoding='utf-8') as file:
kwargs = json.load(file)
except FileNotFoundError as error:
raise FileNotFoundError(
'The config file is not found in the save directory.'
) from error
self._cfgs = Config.dict2config(kwargs)

- def load_running_model(self, env, actor) -> None:
- """Load running model from env and actor.
def __load_model_and_env(self, save_dir: str, model_name: str, env_kwargs: Dict[str, Any]):
"""Load the model from the save directory.
Args:
- env (gym.Env): The environment.
- actor (omnisafe.actor.Actor): The actor.
save_dir (str): directory where the model is saved.
model_name (str): name of the model.
"""
# load the saved model
model_path = os.path.join(save_dir, 'torch_save', model_name)
try:
model_params = torch.load(model_path)
except FileNotFoundError as error:
raise FileNotFoundError('The model is not found in the save directory.') from error

# load the environment
self._env = make(**env_kwargs)

observation_space = self._env.observation_space
action_space = self._env.action_space

assert isinstance(observation_space, Box), 'The observation space must be Box.'
assert isinstance(action_space, Box), 'The action space must be Box.'

if self._cfgs['algo_cfgs']['obs_normalize']:
obs_normalizer = Normalizer(shape=observation_space.shape, clip=5)
obs_normalizer.load_state_dict(model_params['obs_normalizer'])
self._env = ObsNormalize(self._env, obs_normalizer)
if self._env.need_time_limit_wrapper:
self._env = TimeLimit(self._env, time_limit=1000)
self._env = ActionScale(self._env, low=-1.0, high=1.0)

actor_type = self._cfgs['model_cfgs']['actor_type']
pi_cfg = self._cfgs['model_cfgs']['actor']
weight_initialization_mode = self._cfgs['model_cfgs']['weight_initialization_mode']
actor_builder = ActorBuilder(
obs_space=observation_space,
act_space=action_space,
hidden_sizes=pi_cfg['hidden_sizes'],
activation=pi_cfg['activation'],
weight_initialization_mode=weight_initialization_mode,
)
self._actor = actor_builder.build_actor(actor_type)
self._actor.load_state_dict(model_params['pi'])

# pylint: disable-next=too-many-locals
def load_saved(
self,
save_dir: str,
model_name: str,
camera_name: Optional[str] = None,
camera_id: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
):
"""Load a saved model.
Args:
save_dir (str): directory where the model is saved.
model_name (str): name of the model.
"""
# load the config
self._save_dir = save_dir
self._model_name = model_name

self.__load_cfgs(save_dir)

env_kwargs = {
'env_id': self._cfgs['env_id'],
'num_envs': 1,
'render_mode': self._render_mode,
'camera_id': camera_id,
'camera_name': camera_name,
'width': width,
'height': height,
}

- def evaluate(self, num_episode: int, render: bool = False) -> None:
- """Evaluate the model.
self.__load_model_and_env(save_dir, model_name, env_kwargs)

def evaluate(
self,
num_episodes: int = 10,
cost_criteria: float = 1.0,
):
"""Evaluate the agent for num_episodes episodes.
Args:
- num_episode (int): The number of episodes to evaluate.
- render (bool): Whether to render the environment.
num_episodes (int): number of episodes to evaluate the agent.
cost_criteria (float): discount factor applied to each step's cost when accumulating episode cost.
Returns:
(List[float], List[float]): the per-episode returns and the per-episode (discounted) costs.
"""
if self._env is None or self._actor is None:
raise ValueError(
'The environment and the policy must be provided or created before evaluating the agent.'
)

episode_rewards: List[float] = []
episode_costs: List[float] = []
episode_lengths: List[float] = []

for episode in range(num_episodes):
obs, _ = self._env.reset()
ep_ret, ep_cost, length = 0.0, 0.0, 0.0

done = False
while not done:
with torch.no_grad():
act = self._actor.predict(
torch.as_tensor(obs, dtype=torch.float32),
deterministic=False,
)
obs, rew, cost, terminated, truncated, _ = self._env.step(act)

ep_ret += rew.item()
ep_cost += (cost_criteria**length) * cost.item()
length += 1

done = bool(terminated or truncated)

episode_rewards.append(ep_ret)
episode_costs.append(ep_cost)
episode_lengths.append(length)

print(f'Episode {episode+1} results:')
print(f'Episode reward: {ep_ret}')
print(f'Episode cost: {ep_cost}')
print(f'Episode length: {length}')

print(self._dividing_line)
print('Evaluation results:')
print(f'Average episode reward: {np.mean(episode_rewards)}')
print(f'Average episode cost: {np.mean(episode_costs)}')
print(f'Average episode length: {np.mean(episode_lengths)+1}')
return (
episode_rewards,
episode_costs,
)

@property
def fps(self) -> int:
"""The fps of the environment.
Returns:
int: the fps.
"""
try:
fps = self._env.metadata['render_fps']
except AttributeError:
fps = 30
warnings.warn('The fps is not found, use 30 as default.')

return fps

def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
self,
num_episodes: int = 0,
save_replay_path: Optional[str] = None,
max_render_steps: int = 2000,
):
"""Render the environment for one episode.
Args:
seed (int): seed for the environment. If None, the environment will be reset with a random seed.
save_replay_path (str): path to save the replay. If None, no replay is saved.
"""

if save_replay_path is None:
save_replay_path = os.path.join(self._save_dir, 'video', self._model_name.split('.')[0])

horizon = 1000
frames = []
obs, _ = self._env.reset()
if self._render_mode == 'human':
self._env.render()
elif self._render_mode == 'rgb_array':
frames.append(self._env.render())

for episode_idx in range(num_episodes):
step = 0
done = False
while (
not done and step <= max_render_steps
): # a big number to make sure the episode will end
with torch.no_grad():
act = self._actor.predict(obs, deterministic=False)
obs, _, _, terminated, truncated, _ = self._env.step(act)
step += 1
done = bool(terminated or truncated)

if self._render_mode == 'rgb_array':
frames.append(self._env.render())

if self._render_mode == 'rgb_array_list':
frames = self._env.render()

if save_replay_path is not None:
save_video(
frames,
save_replay_path,
fps=self.fps,
episode_trigger=lambda x: True,
video_length=horizon,
episode_index=episode_idx,
name_prefix='eval',
)
self._env.reset()
frames = []
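Taken together, the new workflow is: construct the Evaluator with the desired render mode, point it at a run directory with load_saved, then call evaluate and/or render. A hedged sketch of using the per-episode statistics that evaluate() returns; the run directory and checkpoint name below are placeholders:

import numpy as np

import omnisafe

# Placeholder paths: point these at a real run directory and checkpoint.
LOG_DIR = './runs/PPOLag-SafetyPointGoal1-v0/seed-000'
MODEL_NAME = 'epoch-500.pt'

evaluator = omnisafe.Evaluator(play=False, save_replay=True)
evaluator.load_saved(save_dir=LOG_DIR, model_name=MODEL_NAME, width=256, height=256)

# evaluate() prints per-episode and averaged results, and also returns the raw
# per-episode returns and costs for further analysis.
episode_rewards, episode_costs = evaluator.evaluate(num_episodes=10, cost_criteria=1.0)
print(f'return: {np.mean(episode_rewards):.2f} +/- {np.std(episode_rewards):.2f}')
print(f'cost:   {np.mean(episode_costs):.2f} +/- {np.std(episode_costs):.2f}')

# render() saves a video replay under '<save_dir>/video/<model_name>' when
# save_replay is enabled and no explicit save_replay_path is given.
evaluator.render(num_episodes=1)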
