feat: support policy evaluation #137
Changes from 6 commits
@@ -14,7 +14,7 @@
# ==============================================================================
"""Implementation of Vector Buffer."""

Review comment: Vector Buffer should not appear in this file.

from typing import Tuple
from typing import Any, Mapping, Tuple

import torch
import torch.nn as nn

@@ -108,3 +108,7 @@ def _push(self, raw_data: torch.Tensor) -> None:
        self._var = self._sumsq / (self._count - 1)
        self._std = torch.sqrt(self._var)
        self._std = torch.max(self._std, 1e-2 * torch.ones_like(self._std))

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        self._first = False
        return super().load_state_dict(state_dict, strict)
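
For context, a minimal sketch of how this override might be exercised when restoring a saved observation normalizer. The checkpoint filename and observation dimension below are illustrative assumptions; the 'obs_normalizer' key mirrors how the Evaluator further down reads the checkpoint.

import torch
from omnisafe.common import Normalizer

obs_dim = 8  # illustrative observation dimension
normalizer = Normalizer(shape=(obs_dim,), clip=5)
model_params = torch.load('torch_save/model.pt')  # hypothetical checkpoint path
# The override sets self._first = False before delegating, presumably so a
# restored normalizer is not treated as uninitialized on its next update.
normalizer.load_state_dict(model_params['obs_normalizer'])
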
@@ -14,36 +14,290 @@
# ==============================================================================
"""Implementation of Evaluator."""

import json
import os
import warnings
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from gymnasium.spaces import Box
from gymnasium.utils.save_video import save_video

from omnisafe.common import Normalizer
from omnisafe.envs.core import CMDP, make
from omnisafe.envs.wrapper import ActionScale, ObsNormalize, TimeLimit
from omnisafe.models.actor import ActorBuilder
from omnisafe.models.base import Actor
from omnisafe.utils.config import Config


class Evaluator:  # pylint: disable=too-many-instance-attributes
    """This class includes common evaluation methods for safe RL algorithms."""

    def __init__(self) -> None:
        pass
    # pylint: disable-next=too-many-arguments
    def __init__(
        self,
        play: bool = True,
        save_replay: bool = True,
    ):
        """Initialize the evaluator.

        Args:
            env (gymnasium.Env): the environment. if None, the environment will be created from the config.
            pi (omnisafe.algos.models.actor.Actor): the policy. if None, the policy will be created from the config.
            obs_normalize (omnisafe.algos.models.obs_normalize): the observation Normalize.
        """
        # set the attributes
        self._env: CMDP
        self._actor: Actor

    def load_saved_model(self, save_dir: str, model_name: str) -> None:
        """Load saved model from save_dir.
        # used when load model from saved file.
        self._cfgs: Config
        self._save_dir: str
        self._model_name: str

        # set the render mode
        self._play = play
        self._save_replay = save_replay
        self.__set_render_mode(play, save_replay)

    def __set_render_mode(self, play: bool = True, save_replay: bool = True):
        """Set the render mode.

        Args:

Review comment: I think the ...

            save_dir (str): The directory of saved model.
            model_name (str): The name of saved model.
            render_mode (str): render mode.
        """
        # set the render mode
        if play and save_replay:
            self._render_mode = 'rgb_array'
        elif play and not save_replay:
            self._render_mode = 'human'
        elif not play and save_replay:
            self._render_mode = 'rgb_array_list'
        else:
            raise NotImplementedError('The render mode is not implemented.')
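
As a quick illustration of the mapping implemented above (a usage sketch, not part of the diff):

evaluator = Evaluator(play=True, save_replay=True)    # render_mode = 'rgb_array'
evaluator = Evaluator(play=True, save_replay=False)   # render_mode = 'human'
evaluator = Evaluator(play=False, save_replay=True)   # render_mode = 'rgb_array_list'
# play=False together with save_replay=False raises NotImplementedError.
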

    def __load_cfgs(self, save_dir: str):
        """Load the config from the save directory.

        Args:
            save_dir (str): directory where the model is saved.
        """
        cfg_path = os.path.join(save_dir, 'config.json')
        try:
            with open(cfg_path, encoding='utf-8') as file:
                kwargs = json.load(file)
        except FileNotFoundError as error:
            raise FileNotFoundError(
                'The config file is not found in the save directory.'
            ) from error
        self._cfgs = Config.dict2config(kwargs)

    def load_running_model(self, env, actor) -> None:
        """Load running model from env and actor.
    def __load_model_and_env(self, save_dir: str, model_name: str, env_kwargs: Dict[str, Any]):
        """Load the model from the save directory.

        Args:
            env (gym.Env): The environment.
            actor (omnisafe.actor.Actor): The actor.
            save_dir (str): directory where the model is saved.

Review comment: the same problems

            model_name (str): name of the model.
        """
        # load the saved model
        model_path = os.path.join(save_dir, 'torch_save', model_name)
        try:
            model_params = torch.load(model_path)
        except FileNotFoundError as error:
            raise FileNotFoundError('The model is not found in the save directory.') from error

        # load the environment
        self._env = make(**env_kwargs)

        observation_space = self._env.observation_space
        action_space = self._env.action_space

        assert isinstance(observation_space, Box), 'The observation space must be Box.'
        assert isinstance(action_space, Box), 'The action space must be Box.'

        if self._cfgs['algo_cfgs']['obs_normalize']:
            obs_normalizer = Normalizer(shape=observation_space.shape, clip=5)
            obs_normalizer.load_state_dict(model_params['obs_normalizer'])
            self._env = ObsNormalize(self._env, obs_normalizer)
        if self._env.need_time_limit_wrapper:
            self._env = TimeLimit(self._env, time_limit=1000)
        self._env = ActionScale(self._env, low=-1.0, high=1.0)

        actor_type = self._cfgs['model_cfgs']['actor_type']
        pi_cfg = self._cfgs['model_cfgs']['actor']
        weight_initialization_mode = self._cfgs['model_cfgs']['weight_initialization_mode']
        actor_builder = ActorBuilder(
            obs_space=observation_space,
            act_space=action_space,
            hidden_sizes=pi_cfg['hidden_sizes'],
            activation=pi_cfg['activation'],
            weight_initialization_mode=weight_initialization_mode,
        )
        self._actor = actor_builder.build_actor(actor_type)
        self._actor.load_state_dict(model_params['pi'])

    # pylint: disable-next=too-many-locals
    def load_saved(
        self,
        save_dir: str,
        model_name: str,
        camera_name: Optional[str] = None,
        camera_id: Optional[int] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ):
        """Load a saved model.

        Args:
            save_dir (str): directory where the model is saved.
            model_name (str): name of the model.
        """
        # load the config
        self._save_dir = save_dir
        self._model_name = model_name

        self.__load_cfgs(save_dir)

        env_kwargs = {
            'env_id': self._cfgs['env_id'],
            'num_envs': 1,
            'render_mode': self._render_mode,
            'camera_id': camera_id,
            'camera_name': camera_name,
            'width': width,
            'height': height,
        }

        self.__load_model_and_env(save_dir, model_name, env_kwargs)
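
For illustration, the new public entry point could be driven as below; the run directory and checkpoint file name are hypothetical, not taken from this PR.

evaluator = Evaluator(play=False, save_replay=True)
evaluator.load_saved(
    save_dir='./runs/my_experiment',  # hypothetical run directory
    model_name='model.pt',            # hypothetical checkpoint file
    width=256,
    height=256,
)
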

    def evaluate(self, num_episode: int, render: bool = False) -> None:
        """Evaluate the model.
    def evaluate(
        self,
        num_episodes: int = 10,
        cost_criteria: float = 1.0,
    ):
        """Evaluate the agent for num_episodes episodes.

        Args:
            num_episode (int): The number of episodes to evaluate.
            render (bool): Whether to render the environment.
            num_episodes (int): number of episodes to evaluate the agent.
            cost_criteria (float): the cost criteria for the evaluation.

        Returns:
            episode_rewards (list): list of episode rewards.

Review comment: need to fix,

            episode_costs (list): list of episode costs.
            episode_lengths (list): list of episode lengths.
        """
        if self._env is None or self._actor is None:
            raise ValueError(
                'The environment and the policy must be provided or created before evaluating the agent.'
            )

        episode_rewards: List[float] = []
        episode_costs: List[float] = []
        episode_lengths: List[float] = []

        for episode in range(num_episodes):
            obs, _ = self._env.reset()
            ep_ret, ep_cost, length = 0.0, 0.0, 0.0

            done = False
            while not done:
                with torch.no_grad():
                    act = self._actor.predict(
                        torch.as_tensor(obs, dtype=torch.float32),
                        deterministic=False,
                    )
                obs, rew, cost, terminated, truncated, _ = self._env.step(act)

                ep_ret += rew.item()
                ep_cost += (cost_criteria**length) * cost.item()
                length += 1

                done = bool(terminated or truncated)

            episode_rewards.append(ep_ret)
            episode_costs.append(ep_cost)
            episode_lengths.append(length)

            print(f'Episode {episode+1} results:')
            print(f'Episode reward: {ep_ret}')
            print(f'Episode cost: {ep_cost}')
            print(f'Episode length: {length}')

        print('#' * 50)

        print('Evaluation results:')
        print(f'Average episode reward: {np.mean(episode_rewards)}')
        print(f'Average episode cost: {np.mean(episode_costs)}')
        print(f'Average episode length: {np.mean(episode_lengths)+1}')
        return (
            episode_rewards,
            episode_costs,
        )
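
A note on cost_criteria: the accumulation above makes ep_cost a discounted sum of per-step costs, with cost_criteria acting as the discount factor, so the default of 1.0 gives the plain undiscounted episode cost. A small sketch of the same arithmetic with illustrative values:

costs = [1.0, 0.0, 1.0]   # per-step costs (illustrative)
cost_criteria = 0.99      # discount factor applied to cost
ep_cost = sum(cost_criteria**t * c for t, c in enumerate(costs))
# With cost_criteria = 1.0 this reduces to sum(costs).
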

    @property
    def fps(self) -> int:
        """The fps of the environment.

        Returns:
            int: the fps.
        """
        try:
            fps = self._env.metadata['render_fps']
        except AttributeError:
            fps = 30
            warnings.warn('The fps is not found, use 30 as default.')

        return fps

    def render(  # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
        self,
        num_episodes: int = 0,
        save_replay_path: Optional[str] = None,
    ):
        """Render the environment for one episode.

        Args:
            seed (int): seed for the environment. If None, the environment will be reset with a random seed.
            save_replay_path (str): path to save the replay. If None, no replay is saved.
        """

        if save_replay_path is None:
            save_replay_path = os.path.join(self._save_dir, 'video', self._model_name.split('.')[0])

        horizon = 1000
        frames = []
        obs, _ = self._env.reset()
        if self._render_mode == 'human':
            self._env.render()
        elif self._render_mode == 'rgb_array':
            frames.append(self._env.render())

        for episode_idx in range(num_episodes):
            step = 0
            done = False
            while not done and step <= 2000:  # a big number to make sure the episode will end
                with torch.no_grad():
                    act = self._actor.predict(obs, deterministic=False)

Review comment: why deterministic = False?

                obs, _, _, terminated, truncated, _ = self._env.step(act)
                step += 1
                done = bool(terminated or truncated)

                if self._render_mode == 'rgb_array':
                    frames.append(self._env.render())

            if self._render_mode == 'rgb_array_list':
                frames = self._env.render()

            if save_replay_path is not None:
                save_video(
                    frames,
                    save_replay_path,
                    fps=self.fps,
                    episode_trigger=lambda x: True,
                    video_length=horizon,
                    episode_index=episode_idx,
                    name_prefix='eval',
                )
            self._env.reset()
            frames = []
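
Putting the pieces together, a minimal end-to-end sketch of the workflow this PR adds; the import path, run directory, checkpoint name, and episode counts are illustrative assumptions rather than values from the diff.

from omnisafe.evaluator import Evaluator  # assumed import path, for illustration only

evaluator = Evaluator(play=False, save_replay=True)
evaluator.load_saved(save_dir='./runs/my_experiment', model_name='model.pt')
episode_rewards, episode_costs = evaluator.evaluate(num_episodes=5)
evaluator.render(num_episodes=1)  # by default, videos go under <save_dir>/video/
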