From c41345b8220481bb19476a6767b504693e0e23bd Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Wed, 26 Oct 2022 23:04:38 -0400 Subject: [PATCH 01/11] save videos during training --- src/imitation/scripts/common/common.py | 2 +- src/imitation/scripts/common/rl.py | 1 + src/imitation/scripts/common/train.py | 2 + src/imitation/scripts/train_adversarial.py | 10 ++- .../scripts/train_preference_comparisons.py | 15 +++- src/imitation/scripts/train_rl.py | 14 ++-- src/imitation/util/logger.py | 5 +- src/imitation/util/video_wrapper.py | 37 +++++++++- tests/scripts/test_scripts.py | 69 +++++++++++++++++-- tests/util/test_wb_logger.py | 5 ++ 10 files changed, 139 insertions(+), 21 deletions(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 72d44f2f4..f2eba87d8 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -177,4 +177,4 @@ def make_venv( try: yield venv finally: - venv.close() + venv.close() \ No newline at end of file diff --git a/src/imitation/scripts/common/rl.py b/src/imitation/scripts/common/rl.py index 2bd3759a2..67ec0d785 100644 --- a/src/imitation/scripts/common/rl.py +++ b/src/imitation/scripts/common/rl.py @@ -154,6 +154,7 @@ def make_rl_algo( ) else: raise TypeError(f"Unsupported RL algorithm '{rl_cls}'") + rl_algo = rl_cls( policy=train["policy_cls"], # Note(yawen): Copy `policy_kwargs` as SB3 may mutate the config we pass. diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py index bd7c7d546..7d6a3f689 100644 --- a/src/imitation/scripts/common/train.py +++ b/src/imitation/scripts/common/train.py @@ -10,6 +10,8 @@ from imitation.data import rollout from imitation.policies import base from imitation.scripts.common import common +from imitation.util import video_wrapper + train_ingredient = sacred.Ingredient("train", ingredients=[common.common_ingredient]) logger = logging.getLogger(__name__) diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index b84aec720..a540d5fdf 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -9,6 +9,7 @@ import torch as th from sacred.observers import FileStorageObserver +import imitation.util.video_wrapper as video_wrapper from imitation.algorithms.adversarial import airl as airl_algo from imitation.algorithms.adversarial import common from imitation.algorithms.adversarial import gail as gail_algo @@ -111,9 +112,16 @@ def train_adversarial( sacred.commands.print_config(_run) custom_logger, log_dir = common_config.setup_logging() + checkpoint_dir = log_dir / "checkpoints" + video_dir = checkpoint_dir / "videos" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + video_dir.mkdir(parents=True, exist_ok=True) + expert_trajs = demonstrations.get_expert_trajectories() - with common_config.make_venv() as venv: + post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)] if checkpoint_interval > 0 else None + + with common_config.make_venv(post_wrappers=post_wrappers) as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 331a4797a..dd24be05e 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -12,6 +12,8 @@ from 
sacred.observers import FileStorageObserver from stable_baselines3.common import type_aliases +import gym +import imitation.util.video_wrapper as video_wrapper from imitation.algorithms import preference_comparisons from imitation.data import types from imitation.policies import serialize @@ -149,14 +151,24 @@ def train_preference_comparisons( ValueError: Inconsistency between config and deserialized policy normalization. """ custom_logger, log_dir = common.setup_logging() + checkpoint_dir = log_dir / "checkpoints" + video_dir = checkpoint_dir / "videos" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + video_dir.mkdir(parents=True, exist_ok=True) + rng = common.make_rng() - with common.make_venv() as venv: + post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)] if checkpoint_interval > 0 else None + with common.make_venv(post_wrappers=post_wrappers) as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, update_stats=False, ) + + #if checkpoint_interval > 0: + # venv = VideoWrapper(venv, directory=video_dir, cadence=checkpoint_interval) + if agent_path is None: agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn) else: @@ -287,6 +299,5 @@ def main_console(): train_preference_comparisons_ex.observers.append(observer) train_preference_comparisons_ex.run_commandline() - if __name__ == "__main__": # pragma: no cover main_console() diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index fb7959592..a4ffd8047 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -1,5 +1,4 @@ """Uses RL to train a policy from scratch, saving rollouts and policy. - This can be used: 1. To train a policy on a ground-truth reward function, as a source of synthetic "expert" demonstrations to train IRL or imitation learning @@ -17,6 +16,7 @@ from stable_baselines3.common import callbacks from stable_baselines3.common.vec_env import VecNormalize +import imitation.util.video_wrapper as video_wrapper from imitation.data import rollout, types, wrappers from imitation.policies import serialize from imitation.rewards.reward_wrapper import RewardVecEnvWrapper @@ -42,14 +42,11 @@ def train_rl( agent_path: Optional[str], ) -> Mapping[str, float]: """Trains an expert policy from scratch and saves the rollouts and policy. - Checkpoints: At applicable training steps `step` (where step is either an integer or "final"): - - Policies are saved to `{log_dir}/policies/{step}/`. - Rollouts are saved to `{log_dir}/rollouts/{step}.npz`. - Args: total_timesteps: Number of training timesteps in `model.learn()`. normalize_reward: Applies normalization and clipping to the reward function by @@ -82,7 +79,6 @@ def train_rl( policy_save_final: If True, then save the policy right after training is finished. agent_path: Path to load warm-started agent. - Returns: The return value of `rollout_stats()` using the final policy. 
""" @@ -90,10 +86,16 @@ def train_rl( custom_logger, log_dir = common.setup_logging() rollout_dir = log_dir / "rollouts" policy_dir = log_dir / "policies" + video_dir = log_dir / "videos" rollout_dir.mkdir(parents=True, exist_ok=True) policy_dir.mkdir(parents=True, exist_ok=True) + video_dir.mkdir(parents=True, exist_ok=True) post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] + + if policy_save_interval > 0: + post_wrappers.append(video_wrapper.video_wrapper_factory(video_dir, policy_save_interval)) + with common.make_venv(post_wrappers=post_wrappers) as venv: callback_objs = [] if reward_type is not None: @@ -164,4 +166,4 @@ def main_console(): if __name__ == "__main__": # pragma: no cover - main_console() + main_console() \ No newline at end of file diff --git a/src/imitation/util/logger.py b/src/imitation/util/logger.py index 70190b1fb..c3ca758ba 100644 --- a/src/imitation/util/logger.py +++ b/src/imitation/util/logger.py @@ -377,7 +377,10 @@ def write( if excluded is not None and "wandb" in excluded: continue - self.wandb_module.log({key: value}, step=step) + if key != "video": + self.wandb_module.log({key: value}, step=step) + else: + self.wandb_module.log({"video": self.wandb_module.Video(value)}) self.wandb_module.log({}, commit=True) def close(self) -> None: diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index a59641aa1..d2d262689 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -1,7 +1,7 @@ """Wrapper to record rendered video frames from an environment.""" import pathlib -from typing import Optional +from typing import Optional, Callable import gym from gym.wrappers.monitoring import video_recorder @@ -14,12 +14,16 @@ class VideoWrapper(gym.Wrapper): video_recorder: Optional[video_recorder.VideoRecorder] single_video: bool directory: pathlib.Path + cadence: int + should_record: bool + step_count: int def __init__( self, env: gym.Env, directory: pathlib.Path, single_video: bool = True, + cadence: int = 1 ): """Builds a VideoWrapper. @@ -31,14 +35,21 @@ def __init__( Usually a single video file is what is desired. However, if one is searching for an interesting episode (perhaps by looking at the metadata), then saving to different files can be useful. + cadence: the video wrapper will save a video of the next episode that begins + after every Nth step. So if cadence=100 and each episode has 30 steps, it will + record the 4th episode(first to start after step_count=100) and then the 7th + episode (first to start after step_count=200). """ super().__init__(env) self.episode_id = 0 self.video_recorder = None self.single_video = single_video + self.cadence = cadence self.directory = directory self.directory.mkdir(parents=True, exist_ok=True) + self.should_record = False + self.step_count = 0 def _reset_video_recorder(self) -> None: """Creates a video recorder if one does not already exist. @@ -53,13 +64,14 @@ def _reset_video_recorder(self) -> None: self.video_recorder.close() self.video_recorder = None - if self.video_recorder is None: + if self.video_recorder is None and (self.should_record or self.step_count % self.cadence == 0): # No video recorder -- start a new one. 
self.video_recorder = video_recorder.VideoRecorder( env=self.env, base_path=str(self.directory / f"video.{self.episode_id:06}"), metadata={"episode_id": self.episode_id}, ) + self.should_record = False def reset(self): self._reset_video_recorder() @@ -68,7 +80,11 @@ def reset(self): def step(self, action): res = self.env.step(action) - self.video_recorder.capture_frame() + self.step_count += 1 + if self.step_count % self.cadence == 0: + self.should_record == 0 + if self.video_recorder != None: + self.video_recorder.capture_frame() return res def close(self) -> None: @@ -76,3 +92,18 @@ def close(self) -> None: self.video_recorder.close() self.video_recorder = None super().close() + + +def video_wrapper_factory(video_dir: pathlib.Path, cadence: int, **kwargs) -> Callable: + def f(env: gym.Env, i: int) -> VideoWrapper: + """ + Returns a wrapper around a gym environment records a video if and only if i is 0 + + Args: + env: the environment to be wrapped around + i: the index of the environment. This is to make the video wrapper compatible with + vectorized environments. Only environments with i=0 actually attach the VideoWrapper + """ + + return VideoWrapper(env, directory=video_dir, cadence=cadence, **kwargs) if i == 0 else env + return f \ No newline at end of file diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index fc138b40b..41caacc81 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -1,5 +1,4 @@ """Smoke tests for CLI programs in `imitation.scripts.*`. - Every test in this file should use `parallel=False` to turn off multiprocessing because codecov might interact poorly with multiprocessing. The 'fast' named_config for each experiment implicitly sets parallel=False. @@ -74,9 +73,7 @@ @pytest.fixture(autouse=True) def sacred_capture_use_sys(): """Set Sacred capture mode to "sys" because default "fd" option leads to error. - See https://github.com/IDSIA/sacred/issues/289. - Yields: None after setting capture mode; restores it after yield. """ @@ -602,9 +599,7 @@ def test_train_adversarial_algorithm_value_error(tmpdir): def test_transfer_learning(tmpdir: str) -> None: """Transfer learning smoke test. - Saves a dummy AIRL test reward, then loads it for transfer learning. - Args: tmpdir: Temporary directory to save results to. """ @@ -650,9 +645,7 @@ def test_preference_comparisons_transfer_learning( named_configs_dict: Mapping[str, List[str]], ) -> None: """Transfer learning smoke test. - Saves a preference comparisons ensemble reward, then loads it for transfer learning. - Args: tmpdir: Temporary directory to save results to. named_configs_dict: Named configs for preference_comparisons and rl. @@ -953,3 +946,65 @@ def test_convert_trajs(tmpdir: str): assert len(from_pkl) == len(from_npz) for t_pkl, t_npz in zip(from_pkl, from_npz): assert t_pkl == t_npz + + +_TRAIN_VIDEO_CONFIGS = {"train": {"videos": True}} +# Change the following if the file structure of checkpoints changed. 
+VIDEO_PATH_DICT = dict( + rl=lambda d: d / "videos", + adversarial=lambda d: d / "checkpoints" / "videos", + pc=lambda d: d / "checkpoints" / "videos", + bc=lambda d: n_envs.join(d, "videos"), +) + +def _check_video_exists(log_dir, algo, video_name): + video_dir = VIDEO_PATH_DICT[algo](log_dir) + assert os.path.exists(video_dir) + assert video_name in os.listdir(video_dir) + + +def test_train_rl_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_rl.""" + config_updates = dict( + common=dict(log_root=tmpdir), + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_rl.train_rl_ex.run( + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"], + config_updates=config_updates, + ) + + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "rl", "video.{:06}.mp4".format(0)) + + +def test_train_adversarial_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_adversarial.""" + named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"] + config_updates = dict( + common=dict(log_root=tmpdir), + demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH), + checkpoint_interval=1, + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_adversarial.train_adversarial_ex.run( + command_name="gail", + named_configs=named_configs, + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "adversarial", "video.{:06}.mp4".format(0)) + + +def test_train_preference_comparisons_video_saving(tmpdir): + config_updates = dict( + common=dict(log_root=tmpdir), + checkpoint_interval=1, + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_preference_comparisons.train_preference_comparisons_ex.run( + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "pc", "video.{:06}.mp4".format(0)) \ No newline at end of file diff --git a/tests/util/test_wb_logger.py b/tests/util/test_wb_logger.py index f3b1a85a9..45da780c2 100644 --- a/tests/util/test_wb_logger.py +++ b/tests/util/test_wb_logger.py @@ -113,6 +113,11 @@ def test_wandb_output_format(): {"_step": 0, "foo": 42, "fizz": 12}, {"_step": 3, "fizz": 21}, ] + + with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"): + log_obj.record("video", 42) + log_obj.dump(step=4) + log_obj.close() From a4211ff6edbff8a24919cc9ffa0e05fc01fdf2aa Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Wed, 26 Oct 2022 23:04:38 -0400 Subject: [PATCH 02/11] save videos during training --- src/imitation/scripts/common/common.py | 2 +- src/imitation/scripts/common/rl.py | 1 + src/imitation/scripts/train_adversarial.py | 12 +++- .../scripts/train_preference_comparisons.py | 15 ++++- src/imitation/scripts/train_rl.py | 15 +++-- src/imitation/util/logger.py | 5 +- src/imitation/util/video_wrapper.py | 39 +++++++++-- tests/scripts/test_scripts.py | 65 +++++++++++++++++-- tests/util/test_wb_logger.py | 5 ++ 9 files changed, 140 insertions(+), 19 deletions(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 72d44f2f4..f2eba87d8 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -177,4 +177,4 @@ def make_venv( try: yield venv finally: - venv.close() + venv.close() \ No newline at end of file diff --git a/src/imitation/scripts/common/rl.py b/src/imitation/scripts/common/rl.py index 2bd3759a2..67ec0d785 100644 --- 
a/src/imitation/scripts/common/rl.py +++ b/src/imitation/scripts/common/rl.py @@ -154,6 +154,7 @@ def make_rl_algo( ) else: raise TypeError(f"Unsupported RL algorithm '{rl_cls}'") + rl_algo = rl_cls( policy=train["policy_cls"], # Note(yawen): Copy `policy_kwargs` as SB3 may mutate the config we pass. diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index b84aec720..7ed96e179 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -9,6 +9,7 @@ import torch as th from sacred.observers import FileStorageObserver +import imitation.util.video_wrapper as video_wrapper from imitation.algorithms.adversarial import airl as airl_algo from imitation.algorithms.adversarial import common from imitation.algorithms.adversarial import gail as gail_algo @@ -111,9 +112,18 @@ def train_adversarial( sacred.commands.print_config(_run) custom_logger, log_dir = common_config.setup_logging() + checkpoint_dir = log_dir / "checkpoints" + video_dir = checkpoint_dir / "videos" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + video_dir.mkdir(parents=True, exist_ok=True) + expert_trajs = demonstrations.get_expert_trajectories() - with common_config.make_venv() as venv: + post_wrappers = None + if checkpoint_interval > 0: + post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)] + + with common_config.make_venv(post_wrappers=post_wrappers) as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 331a4797a..d69e75090 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -21,6 +21,8 @@ from imitation.scripts.config.train_preference_comparisons import ( train_preference_comparisons_ex, ) +import imitation.util.video_wrapper as video_wrapper + def save_model( @@ -149,14 +151,24 @@ def train_preference_comparisons( ValueError: Inconsistency between config and deserialized policy normalization. """ custom_logger, log_dir = common.setup_logging() + checkpoint_dir = log_dir / "checkpoints" + video_dir = checkpoint_dir / "videos" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + video_dir.mkdir(parents=True, exist_ok=True) + rng = common.make_rng() - with common.make_venv() as venv: + post_wrappers = None + if checkpoint_interval > 0: + post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)] + + with common.make_venv(post_wrappers=post_wrappers) as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, update_stats=False, ) + if agent_path is None: agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn) else: @@ -287,6 +299,5 @@ def main_console(): train_preference_comparisons_ex.observers.append(observer) train_preference_comparisons_ex.run_commandline() - if __name__ == "__main__": # pragma: no cover main_console() diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index fb7959592..3c0b14e1d 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -1,5 +1,4 @@ """Uses RL to train a policy from scratch, saving rollouts and policy. - This can be used: 1. 
To train a policy on a ground-truth reward function, as a source of synthetic "expert" demonstrations to train IRL or imitation learning @@ -17,6 +16,7 @@ from stable_baselines3.common import callbacks from stable_baselines3.common.vec_env import VecNormalize +import imitation.util.video_wrapper as video_wrapper from imitation.data import rollout, types, wrappers from imitation.policies import serialize from imitation.rewards.reward_wrapper import RewardVecEnvWrapper @@ -41,15 +41,13 @@ def train_rl( policy_save_final: bool, agent_path: Optional[str], ) -> Mapping[str, float]: - """Trains an expert policy from scratch and saves the rollouts and policy. + """Trains an expert policy from scratch and saves the rollouts and policy. Checkpoints: At applicable training steps `step` (where step is either an integer or "final"): - - Policies are saved to `{log_dir}/policies/{step}/`. - Rollouts are saved to `{log_dir}/rollouts/{step}.npz`. - Args: total_timesteps: Number of training timesteps in `model.learn()`. normalize_reward: Applies normalization and clipping to the reward function by @@ -82,7 +80,6 @@ def train_rl( policy_save_final: If True, then save the policy right after training is finished. agent_path: Path to load warm-started agent. - Returns: The return value of `rollout_stats()` using the final policy. """ @@ -90,10 +87,18 @@ def train_rl( custom_logger, log_dir = common.setup_logging() rollout_dir = log_dir / "rollouts" policy_dir = log_dir / "policies" + video_dir = log_dir / "videos" rollout_dir.mkdir(parents=True, exist_ok=True) policy_dir.mkdir(parents=True, exist_ok=True) + video_dir.mkdir(parents=True, exist_ok=True) post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] + + if policy_save_interval > 0: + post_wrappers.append( + video_wrapper.video_wrapper_factory(video_dir, policy_save_interval) + ) + with common.make_venv(post_wrappers=post_wrappers) as venv: callback_objs = [] if reward_type is not None: diff --git a/src/imitation/util/logger.py b/src/imitation/util/logger.py index 70190b1fb..c3ca758ba 100644 --- a/src/imitation/util/logger.py +++ b/src/imitation/util/logger.py @@ -377,7 +377,10 @@ def write( if excluded is not None and "wandb" in excluded: continue - self.wandb_module.log({key: value}, step=step) + if key != "video": + self.wandb_module.log({key: value}, step=step) + else: + self.wandb_module.log({"video": self.wandb_module.Video(value)}) self.wandb_module.log({}, commit=True) def close(self) -> None: diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index a59641aa1..e24469746 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -1,8 +1,7 @@ """Wrapper to record rendered video frames from an environment.""" import pathlib -from typing import Optional - +from typing import Optional, Callable import gym from gym.wrappers.monitoring import video_recorder @@ -14,12 +13,16 @@ class VideoWrapper(gym.Wrapper): video_recorder: Optional[video_recorder.VideoRecorder] single_video: bool directory: pathlib.Path + cadence: int + should_record: bool + step_count: int def __init__( self, env: gym.Env, directory: pathlib.Path, single_video: bool = True, + cadence: int = 1, ): """Builds a VideoWrapper. @@ -31,14 +34,22 @@ def __init__( Usually a single video file is what is desired. However, if one is searching for an interesting episode (perhaps by looking at the metadata), then saving to different files can be useful. 
+ cadence: the video wrapper will save a video of the next episode that + begins after every Nth step. So if cadence=100 and each episode has + 30 steps, it will record the 4th episode(first to start after + step_count=100) and then the 7thepisode (first to start after + step_count=200). """ super().__init__(env) self.episode_id = 0 self.video_recorder = None self.single_video = single_video + self.cadence = cadence self.directory = directory self.directory.mkdir(parents=True, exist_ok=True) + self.should_record = False + self.step_count = 0 def _reset_video_recorder(self) -> None: """Creates a video recorder if one does not already exist. @@ -53,13 +64,14 @@ def _reset_video_recorder(self) -> None: self.video_recorder.close() self.video_recorder = None - if self.video_recorder is None: + if self.video_recorder is None and (self.should_record or self.step_count % self.cadence == 0): # No video recorder -- start a new one. self.video_recorder = video_recorder.VideoRecorder( env=self.env, base_path=str(self.directory / f"video.{self.episode_id:06}"), metadata={"episode_id": self.episode_id}, ) + self.should_record = False def reset(self): self._reset_video_recorder() @@ -68,7 +80,11 @@ def reset(self): def step(self, action): res = self.env.step(action) - self.video_recorder.capture_frame() + self.step_count += 1 + if self.step_count % self.cadence == 0: + self.should_record == 0 + if self.video_recorder != None: + self.video_recorder.capture_frame() return res def close(self) -> None: @@ -76,3 +92,18 @@ def close(self) -> None: self.video_recorder.close() self.video_recorder = None super().close() + + +def video_wrapper_factory(video_dir: pathlib.Path, cadence: int, **kwargs) -> Callable: + def f(env: gym.Env, i: int) -> VideoWrapper: + """ + Returns a wrapper around a gym environment records a video if and only if i is 0 + + Args: + env: the environment to be wrapped around + i: the index of the environment. This is to make the video wrapper compatible with + vectorized environments. Only environments with i=0 actually attach the VideoWrapper + """ + + return VideoWrapper(env, directory=video_dir, cadence=cadence, **kwargs) if i == 0 else env + return f diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index fc138b40b..821ec1994 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -73,10 +73,9 @@ @pytest.fixture(autouse=True) def sacred_capture_use_sys(): - """Set Sacred capture mode to "sys" because default "fd" option leads to error. + """Set Sacred capture mode to "sys" because default "fd" option leads to error. See https://github.com/IDSIA/sacred/issues/289. - Yields: None after setting capture mode; restores it after yield. """ @@ -604,7 +603,6 @@ def test_transfer_learning(tmpdir: str) -> None: """Transfer learning smoke test. Saves a dummy AIRL test reward, then loads it for transfer learning. - Args: tmpdir: Temporary directory to save results to. """ @@ -649,10 +647,9 @@ def test_preference_comparisons_transfer_learning( tmpdir: str, named_configs_dict: Mapping[str, List[str]], ) -> None: - """Transfer learning smoke test. + """Transfer learning smoke test. Saves a preference comparisons ensemble reward, then loads it for transfer learning. - Args: tmpdir: Temporary directory to save results to. named_configs_dict: Named configs for preference_comparisons and rl. 
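For reference, the `video_wrapper_factory` defined above produces the `(env, index) -> env` callable that the training scripts hand to `make_venv(post_wrappers=...)`; only the environment with index 0 is wrapped. A minimal standalone sketch of that behaviour (the directory, cadence value, and CartPole environment are illustrative, not taken from this patch):

    import pathlib

    import gym

    from imitation.util import video_wrapper

    # Build the per-index wrapper callable and apply it the way a vectorized
    # environment's post_wrappers are applied.
    wrap = video_wrapper.video_wrapper_factory(pathlib.Path("videos"), 500)
    envs = [wrap(gym.make("CartPole-v1"), i) for i in range(4)]
    # envs[0] is a VideoWrapper; envs[1:] are returned unchanged.
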
@@ -953,3 +950,61 @@ def test_convert_trajs(tmpdir: str): assert len(from_pkl) == len(from_npz) for t_pkl, t_npz in zip(from_pkl, from_npz): assert t_pkl == t_npz + + +#_TRAIN_VIDEO_CONFIGS = {"train": {"videos": True}} +# Change the following if the file structure of checkpoints changed. +VIDEO_FILE_PATH = "video.{:06}.mp4".format(0) +VIDEO_PATH_DICT = dict( + rl=lambda d: d / "videos", + adversarial=lambda d: d / "checkpoints" / "videos", + pc=lambda d: d / "checkpoints" / "videos" +) + + +def _check_video_exists(log_dir, algo): + video_dir = VIDEO_PATH_DICT[algo](log_dir) + assert os.path.exists(video_dir) + assert VIDEO_FILE_PATH in os.listdir(video_dir) + + +def test_train_rl_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_rl.""" + config_updates = dict( + common=dict(log_root=tmpdir) ) + run = train_rl.train_rl_ex.run( + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"], + config_updates=config_updates, + ) + + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "rl") + + +def test_train_adversarial_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_adversarial.""" + named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"] + config_updates = dict( + common=dict(log_root=tmpdir), + demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH), + checkpoint_interval=1 ) + run = train_adversarial.train_adversarial_ex.run( + command_name="gail", + named_configs=named_configs, + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "adversarial") + + +def test_train_preference_comparisons_video_saving(tmpdir): + config_updates = dict( + common=dict(log_root=tmpdir), + checkpoint_interval=1 + ) + run = train_preference_comparisons.train_preference_comparisons_ex.run( + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "pc") diff --git a/tests/util/test_wb_logger.py b/tests/util/test_wb_logger.py index f3b1a85a9..7bbfbdb63 100644 --- a/tests/util/test_wb_logger.py +++ b/tests/util/test_wb_logger.py @@ -113,6 +113,11 @@ def test_wandb_output_format(): {"_step": 0, "foo": 42, "fizz": 12}, {"_step": 3, "fizz": 21}, ] + + with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"): + log_obj.record("video", 42) + log_obj.dump(step=4) + log_obj.close() From 5e6ab2509d32c66b12a5c953ad9b554b335450ec Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Fri, 28 Oct 2022 12:59:41 -0400 Subject: [PATCH 03/11] refactored video saving and enabled it for bc imitation --- src/imitation/scripts/common/common.py | 51 ++++++++++++++++++- src/imitation/scripts/common/train.py | 2 - .../scripts/config/train_adversarial.py | 1 + .../scripts/config/train_imitation.py | 1 + .../config/train_preference_comparisons.py | 1 + src/imitation/scripts/config/train_rl.py | 1 + src/imitation/scripts/train_adversarial.py | 19 ++++--- src/imitation/scripts/train_imitation.py | 13 ++++- .../scripts/train_preference_comparisons.py | 22 ++++---- src/imitation/scripts/train_rl.py | 17 ++++--- src/imitation/util/video_wrapper.py | 34 ++++++++----- tests/scripts/test_scripts.py | 39 +++++++++++--- 12 files changed, 152 insertions(+), 49 deletions(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index f2eba87d8..98a9f686c 100644 --- a/src/imitation/scripts/common/common.py +++ 
b/src/imitation/scripts/common/common.py @@ -3,8 +3,9 @@ import contextlib import logging import pathlib -from typing import Any, Generator, Mapping, Sequence, Tuple, Union +from typing import Any, Generator, Mapping, Sequence, Tuple, Union, Optional, Callable +import gym import numpy as np import sacred from stable_baselines3.common import vec_env @@ -14,6 +15,7 @@ from imitation.util import logger as imit_logger from imitation.util import sacred as sacred_util from imitation.util import util +import imitation.util.video_wrapper as video_wrapper common_ingredient = sacred.Ingredient("common", ingredients=[wb.wandb_ingredient]) logger = logging.getLogger(__name__) @@ -132,6 +134,51 @@ def setup_logging( ) return custom_logger, log_dir +@common_ingredient.capture +def setup_video_saving( + _run, + base_dir: pathlib.Path, + video_save_interval: int, + post_wrappers: Optional[Sequence[Callable[[gym.Env, int], gym.Env]]] = None, +) -> Optional[Sequence[Callable[[gym.Env, int], gym.Env]]]: + """Adds video saving to the existing post wrappers. + + Args: + base_dir: the videos will be saved to a videos subdirectory under + this base directory + video_save_interval: the video wrapper will save a video of the next + episode that begins after every Nth step. So if + video_save_interval=100 and each episode has 30 steps, it will record + the 4th episode(first to start after step_count=100) and then the 7th + episode (first to start after step_count=200). + post_wrappers: If specified, iteratively wraps each environment with each + of the wrappers specified in the sequence. The argument should be a + Callable accepting two arguments, the Env to be wrapped and the + environment index, and returning the wrapped Env. + + Returns: + A new post_wrapper list with the video saving wrapper appended to the + existing list. If the existing post wrapper was null, it will create a + new list with just the video wrapper. 
If the video_save_interval is <=0, + it will just return the inputted post_wrapper + """ + + if video_save_interval > 0: + + video_dir = base_dir / "videos" + video_dir.mkdir(parents=True, exist_ok=True) + + post_wrappers_copy = [wrapper for wrapper in post_wrappers] \ + if post_wrappers != None else [] + post_wrappers_copy.append( + video_wrapper.video_wrapper_factory(video_dir, video_save_interval) + ) + + return post_wrappers_copy + + return post_wrappers + + @contextlib.contextmanager @common_ingredient.capture @@ -177,4 +224,4 @@ def make_venv( try: yield venv finally: - venv.close() \ No newline at end of file + venv.close() diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py index 7d6a3f689..bd7c7d546 100644 --- a/src/imitation/scripts/common/train.py +++ b/src/imitation/scripts/common/train.py @@ -10,8 +10,6 @@ from imitation.data import rollout from imitation.policies import base from imitation.scripts.common import common -from imitation.util import video_wrapper - train_ingredient = sacred.Ingredient("train", ingredients=[common.common_ingredient]) logger = logging.getLogger(__name__) diff --git a/src/imitation/scripts/config/train_adversarial.py b/src/imitation/scripts/config/train_adversarial.py index 3183ac9f6..3779350eb 100644 --- a/src/imitation/scripts/config/train_adversarial.py +++ b/src/imitation/scripts/config/train_adversarial.py @@ -30,6 +30,7 @@ def defaults(): algorithm_specific = {} # algorithm_specific[algorithm] is merged with config checkpoint_interval = 0 # Num epochs between checkpoints (<0 disables) + video_save_interval = 0 # Number of steps before saving video (<=0 disables) agent_path = None # Path to load agent from, optional. diff --git a/src/imitation/scripts/config/train_imitation.py b/src/imitation/scripts/config/train_imitation.py index c2466a936..71f0d6f6b 100644 --- a/src/imitation/scripts/config/train_imitation.py +++ b/src/imitation/scripts/config/train_imitation.py @@ -38,6 +38,7 @@ def config(): total_timesteps=1e5, ) agent_path = None # Path to load agent from, optional. + video_save_interval = 0 # <=0 means no saving @train_imitation_ex.named_config diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index ba4e9483c..9f0d554fd 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -58,6 +58,7 @@ def train_defaults(): allow_variable_horizon = False checkpoint_interval = 0 # Num epochs between saving (<0 disables, =0 final only) + video_save_interval = 0 # Number of steps before saving video (<=0 disables) query_schedule = "hyperbolic" diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index b9ede3165..b5ed922c0 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -30,6 +30,7 @@ def train_rl_defaults(): rollout_save_n_episodes = None # Num episodes saved per file, optional. policy_save_interval = 10000 # Num timesteps between saves (<=0 disables) + video_save_interval = 0 # Number of steps before saving video (<=0 disables) policy_save_final = True # If True, save after training is finished. agent_path = None # Path to load agent from, optional. 
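`setup_video_saving` above appends the video wrapper to whatever `post_wrappers` the caller already has, and the vec-env helper in `imitation.util.util` applies post_wrappers in order, so for environment index 0 the video wrapper ends up outermost. A hedged sketch of the equivalent composition done by hand (environment id and interval are illustrative):

    import pathlib

    import gym

    from imitation.data import wrappers
    from imitation.util import video_wrapper

    post_wrappers = [
        lambda env, idx: wrappers.RolloutInfoWrapper(env),
        video_wrapper.video_wrapper_factory(pathlib.Path("videos"), 10_000),
    ]

    env0 = gym.make("CartPole-v1")
    for wrap in post_wrappers:
        env0 = wrap(env0, 0)
    # env0 is now VideoWrapper(RolloutInfoWrapper(raw_env)); any index other
    # than 0 would stop at RolloutInfoWrapper.
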
diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index 7ed96e179..473aac704 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -72,6 +72,7 @@ def train_adversarial( algorithm_kwargs: Mapping[str, Any], total_timesteps: int, checkpoint_interval: int, + video_save_interval: int, agent_path: Optional[str], ) -> Mapping[str, Mapping[str, float]]: """Train an adversarial-network-based imitation learning algorithm. @@ -94,6 +95,10 @@ def train_adversarial( `checkpoint_interval` rounds and after training is complete. If 0, then only save weights after training is complete. If <0, then don't save weights at all. + video_save_interval: The number of steps to take before saving a video. + After that step count is reached, the step count is reset and the next + episode will be recorded in full. Empty or negative values means no + video is saved. agent_path: Path to a directory containing a pre-trained agent. If provided, then the agent will be initialized using this stored policy (warm start). If not provided, then the agent will be initialized using @@ -113,15 +118,15 @@ def train_adversarial( custom_logger, log_dir = common_config.setup_logging() checkpoint_dir = log_dir / "checkpoints" - video_dir = checkpoint_dir / "videos" checkpoint_dir.mkdir(parents=True, exist_ok=True) - video_dir.mkdir(parents=True, exist_ok=True) expert_trajs = demonstrations.get_expert_trajectories() - post_wrappers = None - if checkpoint_interval > 0: - post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)] + post_wrappers = common_config.setup_video_saving( + base_dir=checkpoint_dir, + video_save_interval=video_save_interval, + post_wrappers=None + ) with common_config.make_venv(post_wrappers=post_wrappers) as venv: reward_net = reward.make_reward_net(venv) @@ -159,14 +164,14 @@ def train_adversarial( def callback(round_num: int, /) -> None: if checkpoint_interval > 0 and round_num % checkpoint_interval == 0: - save(trainer, log_dir / "checkpoints" / f"{round_num:05d}") + save(trainer, checkpoint_dir / f"{round_num:05d}") trainer.train(total_timesteps, callback) imit_stats = train.eval_policy(trainer.policy, trainer.venv_train) # Save final artifacts. if checkpoint_interval >= 0: - save(trainer, log_dir / "checkpoints" / "final") + save(trainer, checkpoint_dir / "final") return { "imit_stats": imit_stats, diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 09393366e..8fd9b9132 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -71,6 +71,7 @@ def train_imitation( dagger: Mapping[str, Any], use_dagger: bool, agent_path: Optional[str], + video_save_interval: int ) -> Mapping[str, Mapping[str, float]]: """Runs DAgger (if `use_dagger`) or BC (otherwise) training. @@ -82,6 +83,10 @@ def train_imitation( agent_path: Path to serialized policy. If provided, then load the policy from this path. Otherwise, make a new policy. Specify only if policy_cls and policy_kwargs are not specified. + video_save_interval: The number of steps to take before saving a video. + After that step count is reached, the step count is reset and the next + episode will be recorded in full. Empty or negative values means no + video is saved. Returns: Statistics for rollouts from the trained policy and demonstration data. 
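As a usage illustration for the option documented above, a BC run with video recording enabled might look like the following; the rollout path is a placeholder and `video_save_interval=1` mirrors the smoke tests added later in this series:

    from imitation.scripts.train_imitation import train_imitation_ex

    run = train_imitation_ex.run(
        command_name="bc",
        named_configs=["cartpole"],
        config_updates=dict(
            demonstrations=dict(rollout_path="path/to/cartpole_rollouts.npz"),
            video_save_interval=1,
        ),
    )
    # Videos are written under {log_dir}/videos, since train_imitation passes
    # base_dir=log_dir to setup_video_saving.
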
@@ -89,7 +94,13 @@ def train_imitation( rng = common.make_rng() custom_logger, log_dir = common.setup_logging() - with common.make_venv() as venv: + post_wrappers = common.setup_video_saving( + base_dir=log_dir, + video_save_interval=video_save_interval, + post_wrappers=None + ) + + with common.make_venv(post_wrappers=post_wrappers) as venv: imit_policy = make_policy(venv, agent_path=agent_path) expert_trajs: Optional[Sequence[types.Trajectory]] = None diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index d871a2fcf..15abcfd76 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -13,7 +13,6 @@ from stable_baselines3.common import type_aliases import gym -import imitation.util.video_wrapper as video_wrapper from imitation.algorithms import preference_comparisons from imitation.data import types from imitation.policies import serialize @@ -85,6 +84,7 @@ def train_preference_comparisons( fragmenter_kwargs: Mapping[str, Any], allow_variable_horizon: bool, checkpoint_interval: int, + video_save_interval: int, query_schedule: Union[str, type_aliases.Schedule], ) -> Mapping[str, Any]: """Train a reward model using preference comparisons. @@ -140,6 +140,10 @@ def train_preference_comparisons( trajectory_generator contains a policy) every `checkpoint_interval` iterations and after training is complete. If 0, then only save weights after training is complete. If <0, then don't save weights at all. + video_save_interval: The number of steps to take before saving a video. + After that step count is reached, the step count is reset and the next + episode will be recorded in full. Empty or negative values means no + video is saved. query_schedule: one of ("constant", "hyperbolic", "inverse_quadratic"). A function indicating how the total number of preference queries should be allocated to each iteration. 
"hyperbolic" and "inverse_quadratic" @@ -154,16 +158,16 @@ def train_preference_comparisons( """ custom_logger, log_dir = common.setup_logging() checkpoint_dir = log_dir / "checkpoints" - video_dir = checkpoint_dir / "videos" checkpoint_dir.mkdir(parents=True, exist_ok=True) - video_dir.mkdir(parents=True, exist_ok=True) rng = common.make_rng() - post_wrappers = None - if checkpoint_interval > 0: - post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)] - + post_wrappers = common.setup_video_saving( + base_dir=checkpoint_dir, + video_save_interval=video_save_interval, + post_wrappers=None + ) + with common.make_venv(post_wrappers=post_wrappers) as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( @@ -264,7 +268,7 @@ def save_callback(iteration_num): if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0: save_checkpoint( trainer=main_trainer, - save_path=log_dir / "checkpoints" / f"{iteration_num:04d}", + save_path=checkpoint_dir / f"{iteration_num:04d}", allow_save_policy=bool(trajectory_path is None), ) @@ -286,7 +290,7 @@ def save_callback(iteration_num): if checkpoint_interval >= 0: save_checkpoint( trainer=main_trainer, - save_path=log_dir / "checkpoints" / "final", + save_path=checkpoint_dir / "final", allow_save_policy=bool(trajectory_path is None), ) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 9d237304d..d1e1fd7f1 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -38,10 +38,10 @@ def train_rl( rollout_save_n_timesteps: Optional[int], rollout_save_n_episodes: Optional[int], policy_save_interval: int, + video_save_interval: int, policy_save_final: bool, agent_path: Optional[str], ) -> Mapping[str, float]: - """Trains an expert policy from scratch and saves the rollouts and policy. Checkpoints: At applicable training steps `step` (where step is either an integer or @@ -77,6 +77,10 @@ def train_rl( policy_save_interval: The number of training updates between in between intermediate rollout saves. If the argument is nonpositive, then don't save intermediate updates. + video_save_interval: The number of steps to take before saving a video. + After that step count is reached, the step count is reset and the next + episode will be recorded in full. Empty or negative values means no + video is saved. policy_save_final: If True, then save the policy right after training is finished. agent_path: Path to load warm-started agent. 
@@ -87,17 +91,16 @@ def train_rl( custom_logger, log_dir = common.setup_logging() rollout_dir = log_dir / "rollouts" policy_dir = log_dir / "policies" - video_dir = log_dir / "videos" rollout_dir.mkdir(parents=True, exist_ok=True) policy_dir.mkdir(parents=True, exist_ok=True) - video_dir.mkdir(parents=True, exist_ok=True) post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] - if policy_save_interval > 0: - post_wrappers.append( - video_wrapper.video_wrapper_factory(video_dir, policy_save_interval) - ) + post_wrappers = common.setup_video_saving( + base_dir=log_dir, + video_save_interval=video_save_interval, + post_wrappers=post_wrappers + ) with common.make_venv(post_wrappers=post_wrappers) as venv: callback_objs = [] diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index e24469746..5717f2de1 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -13,7 +13,7 @@ class VideoWrapper(gym.Wrapper): video_recorder: Optional[video_recorder.VideoRecorder] single_video: bool directory: pathlib.Path - cadence: int + video_save_interval: int should_record: bool step_count: int @@ -22,7 +22,7 @@ def __init__( env: gym.Env, directory: pathlib.Path, single_video: bool = True, - cadence: int = 1, + video_save_interval: int = 1, ): """Builds a VideoWrapper. @@ -34,17 +34,17 @@ def __init__( Usually a single video file is what is desired. However, if one is searching for an interesting episode (perhaps by looking at the metadata), then saving to different files can be useful. - cadence: the video wrapper will save a video of the next episode that - begins after every Nth step. So if cadence=100 and each episode has - 30 steps, it will record the 4th episode(first to start after - step_count=100) and then the 7thepisode (first to start after - step_count=200). + video_save_interval: the video wrapper will save a video of the next + episode that begins after every Nth step. So if + video_save_interval=100 and each episode has 30 steps, it will record + the 4th episode(first to start after step_count=100) and then the 7th + episode (first to start after step_count=200). """ super().__init__(env) self.episode_id = 0 self.video_recorder = None self.single_video = single_video - self.cadence = cadence + self.video_save_interval = video_save_interval self.directory = directory self.directory.mkdir(parents=True, exist_ok=True) @@ -64,7 +64,8 @@ def _reset_video_recorder(self) -> None: self.video_recorder.close() self.video_recorder = None - if self.video_recorder is None and (self.should_record or self.step_count % self.cadence == 0): + if self.video_recorder is None and \ + (self.should_record or self.step_count % self.video_save_interval == 0): # No video recorder -- start a new one. 
self.video_recorder = video_recorder.VideoRecorder( env=self.env, @@ -81,7 +82,7 @@ def reset(self): def step(self, action): res = self.env.step(action) self.step_count += 1 - if self.step_count % self.cadence == 0: + if self.step_count % self.video_save_interval == 0: self.should_record == 0 if self.video_recorder != None: self.video_recorder.capture_frame() @@ -94,7 +95,10 @@ def close(self) -> None: super().close() -def video_wrapper_factory(video_dir: pathlib.Path, cadence: int, **kwargs) -> Callable: +def video_wrapper_factory( + video_dir: pathlib.Path, + video_save_interval: int, + **kwargs) -> Callable: def f(env: gym.Env, i: int) -> VideoWrapper: """ Returns a wrapper around a gym environment records a video if and only if i is 0 @@ -104,6 +108,10 @@ def f(env: gym.Env, i: int) -> VideoWrapper: i: the index of the environment. This is to make the video wrapper compatible with vectorized environments. Only environments with i=0 actually attach the VideoWrapper """ - - return VideoWrapper(env, directory=video_dir, cadence=cadence, **kwargs) if i == 0 else env + return VideoWrapper( + env, + directory=video_dir, + video_save_interval=video_save_interval, + **kwargs + ) if i == 0 else env return f diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 24ede27f9..c8ab005fa 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -72,9 +72,9 @@ @pytest.fixture(autouse=True) def sacred_capture_use_sys(): - """Set Sacred capture mode to "sys" because default "fd" option leads to error. See https://github.com/IDSIA/sacred/issues/289. + Yields: None after setting capture mode; restores it after yield. """ @@ -600,7 +600,9 @@ def test_train_adversarial_algorithm_value_error(tmpdir): def test_transfer_learning(tmpdir: str) -> None: """Transfer learning smoke test. + Saves a dummy AIRL test reward, then loads it for transfer learning. + Args: tmpdir: Temporary directory to save results to. """ @@ -645,9 +647,9 @@ def test_preference_comparisons_transfer_learning( tmpdir: str, named_configs_dict: Mapping[str, List[str]], ) -> None: - """Transfer learning smoke test. Saves a preference comparisons ensemble reward, then loads it for transfer learning. + Args: tmpdir: Temporary directory to save results to. named_configs_dict: Named configs for preference_comparisons and rl. 
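The test helpers in this file look for `VIDEO_FILE_PATH = "video.{:06}.mp4".format(0)`. That name follows from `VideoWrapper._reset_video_recorder`, which passes `base_path=str(self.directory / f"video.{self.episode_id:06}")` to gym's `VideoRecorder`, and the recorder itself appends the `.mp4` extension. A quick standalone check of the pattern:

    # episode_id 0, zero-padded to six digits, plus the ".mp4" suffix added by
    # gym.wrappers.monitoring.video_recorder.VideoRecorder.
    episode_id = 0
    assert f"video.{episode_id:06}.mp4" == "video.{:06}.mp4".format(0) == "video.000000.mp4"
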
@@ -954,18 +956,22 @@ def test_convert_trajs(tmpdir: str): VIDEO_PATH_DICT = dict( rl=lambda d: d / "videos", adversarial=lambda d: d / "checkpoints" / "videos", - pc=lambda d: d / "checkpoints" / "videos" + pc=lambda d: d / "checkpoints" / "videos", + bc=lambda d: d / "videos" ) def _check_video_exists(log_dir, algo): video_dir = VIDEO_PATH_DICT[algo](log_dir) - assert os.path.exists(video_dir) - assert VIDEO_FILE_PATH in os.listdir(video_dir) + video_file = video_dir / VIDEO_FILE_PATH + assert video_dir.exists() + assert video_file.exists() +@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg takes a long time to install") def test_train_rl_video_saving(tmpdir): """Smoke test for imitation.scripts.train_rl.""" config_updates = dict( - common=dict(log_root=tmpdir) + common=dict(log_root=tmpdir), + video_save_interval=1, ) run = train_rl.train_rl_ex.run( named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"], @@ -975,13 +981,14 @@ def test_train_rl_video_saving(tmpdir): assert run.status == "COMPLETED" _check_video_exists(run.config["common"]["log_dir"], "rl") +@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg takes a long time to install") def test_train_adversarial_video_saving(tmpdir): """Smoke test for imitation.scripts.train_adversarial.""" named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"] config_updates = dict( common=dict(log_root=tmpdir), demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH), - checkpoint_interval=1 + video_save_interval=1, ) run = train_adversarial.train_adversarial_ex.run( command_name="gail", @@ -991,10 +998,11 @@ def test_train_adversarial_video_saving(tmpdir): assert run.status == "COMPLETED" _check_video_exists(run.config["common"]["log_dir"], "adversarial") +@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg takes a long time to install") def test_train_preference_comparisons_video_saving(tmpdir): config_updates = dict( common=dict(log_root=tmpdir), - checkpoint_interval=1 + video_save_interval=1, ) run = train_preference_comparisons.train_preference_comparisons_ex.run( named_configs=["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"], @@ -1002,3 +1010,18 @@ def test_train_preference_comparisons_video_saving(tmpdir): ) assert run.status == "COMPLETED" _check_video_exists(run.config["common"]["log_dir"], "pc") + +@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg takes a long time to install") +def test_train_bc_video_saving(tmpdir): + config_updates = dict( + common=dict(log_root=tmpdir), + demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH), + video_save_interval=1, + ) + run = train_imitation.train_imitation_ex.run( + command_name="bc", + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "bc") From 975124808f53ec565f8f095b7bee343ff02dd6a1 Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Fri, 28 Oct 2022 23:37:35 -0400 Subject: [PATCH 04/11] fix single-video bug in video wrapper --- src/imitation/scripts/eval_policy.py | 6 +++++- src/imitation/util/video_wrapper.py | 11 ++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 06eab4820..1f47e6b38 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -44,7 +44,11 @@ def video_wrapper_factory(log_dir: pathlib.Path, **kwargs): def f(env: gym.Env, i: int) -> 
video_wrapper.VideoWrapper: """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`.""" directory = log_dir / "videos" / str(i) - return video_wrapper.VideoWrapper(env, directory=directory, **kwargs) + return video_wrapper.VideoWrapper( + env, + single_video=False, + directory=directory, + **kwargs) return f diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 5717f2de1..12b33bf58 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -21,7 +21,7 @@ def __init__( self, env: gym.Env, directory: pathlib.Path, - single_video: bool = True, + single_video: bool = False, video_save_interval: int = 1, ): """Builds a VideoWrapper. @@ -38,7 +38,8 @@ def __init__( episode that begins after every Nth step. So if video_save_interval=100 and each episode has 30 steps, it will record the 4th episode(first to start after step_count=100) and then the 7th - episode (first to start after step_count=200). + episode (first to start after step_count=200). If single_video is true, + then this value does not apply as a single video is recorded throughout """ super().__init__(env) self.episode_id = 0 @@ -48,7 +49,7 @@ def __init__( self.directory = directory self.directory.mkdir(parents=True, exist_ok=True) - self.should_record = False + self.should_record = self.single_video self.step_count = 0 def _reset_video_recorder(self) -> None: @@ -72,7 +73,7 @@ def _reset_video_recorder(self) -> None: base_path=str(self.directory / f"video.{self.episode_id:06}"), metadata={"episode_id": self.episode_id}, ) - self.should_record = False + self.should_record = self.single_video def reset(self): self._reset_video_recorder() @@ -83,7 +84,7 @@ def step(self, action): res = self.env.step(action) self.step_count += 1 if self.step_count % self.video_save_interval == 0: - self.should_record == 0 + self.should_record = True if self.video_recorder != None: self.video_recorder.capture_frame() return res From b79d863226ae6928bedefac9aec313a78ecc214e Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Sat, 29 Oct 2022 16:55:09 -0400 Subject: [PATCH 05/11] fix linter errors --- src/imitation/scripts/common/common.py | 9 +++----- .../scripts/config/train_imitation.py | 2 +- src/imitation/scripts/train_adversarial.py | 2 -- src/imitation/scripts/train_imitation.py | 3 +-- .../scripts/train_preference_comparisons.py | 8 ++----- src/imitation/scripts/train_rl.py | 9 +++++--- src/imitation/util/video_wrapper.py | 21 ++++++++++++------- tests/scripts/test_scripts.py | 19 ++++++++++++----- 8 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 98a9f686c..d6add6762 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -134,6 +134,7 @@ def setup_logging( ) return custom_logger, log_dir + @common_ingredient.capture def setup_video_saving( _run, @@ -162,16 +163,13 @@ def setup_video_saving( new list with just the video wrapper. 
If the video_save_interval is <=0, it will just return the inputted post_wrapper """ - if video_save_interval > 0: - video_dir = base_dir / "videos" video_dir.mkdir(parents=True, exist_ok=True) - post_wrappers_copy = [wrapper for wrapper in post_wrappers] \ - if post_wrappers != None else [] + post_wrappers_copy = list(post_wrappers) if post_wrappers is not None else [] post_wrappers_copy.append( - video_wrapper.video_wrapper_factory(video_dir, video_save_interval) + video_wrapper.video_wrapper_factory(video_dir, video_save_interval), ) return post_wrappers_copy @@ -179,7 +177,6 @@ def setup_video_saving( return post_wrappers - @contextlib.contextmanager @common_ingredient.capture def make_venv( diff --git a/src/imitation/scripts/config/train_imitation.py b/src/imitation/scripts/config/train_imitation.py index 71f0d6f6b..6d8360358 100644 --- a/src/imitation/scripts/config/train_imitation.py +++ b/src/imitation/scripts/config/train_imitation.py @@ -38,7 +38,7 @@ def config(): total_timesteps=1e5, ) agent_path = None # Path to load agent from, optional. - video_save_interval = 0 # <=0 means no saving + video_save_interval = 0 # <=0 means no saving @train_imitation_ex.named_config diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index 473aac704..5dcb534e2 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -9,7 +9,6 @@ import torch as th from sacred.observers import FileStorageObserver -import imitation.util.video_wrapper as video_wrapper from imitation.algorithms.adversarial import airl as airl_algo from imitation.algorithms.adversarial import common from imitation.algorithms.adversarial import gail as gail_algo @@ -125,7 +124,6 @@ def train_adversarial( post_wrappers = common_config.setup_video_saving( base_dir=checkpoint_dir, video_save_interval=video_save_interval, - post_wrappers=None ) with common_config.make_venv(post_wrappers=post_wrappers) as venv: diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 8fd9b9132..4276b710a 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -71,7 +71,7 @@ def train_imitation( dagger: Mapping[str, Any], use_dagger: bool, agent_path: Optional[str], - video_save_interval: int + video_save_interval: int, ) -> Mapping[str, Mapping[str, float]]: """Runs DAgger (if `use_dagger`) or BC (otherwise) training. 
@@ -97,7 +97,6 @@ def train_imitation( post_wrappers = common.setup_video_saving( base_dir=log_dir, video_save_interval=video_save_interval, - post_wrappers=None ) with common.make_venv(post_wrappers=post_wrappers) as venv: diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 15abcfd76..d5a754ccd 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -12,7 +12,6 @@ from sacred.observers import FileStorageObserver from stable_baselines3.common import type_aliases -import gym from imitation.algorithms import preference_comparisons from imitation.data import types from imitation.policies import serialize @@ -22,9 +21,6 @@ from imitation.scripts.config.train_preference_comparisons import ( train_preference_comparisons_ex, ) -import imitation.util.video_wrapper as video_wrapper - - def save_model( agent_trainer: preference_comparisons.AgentTrainer, @@ -164,8 +160,7 @@ def train_preference_comparisons( post_wrappers = common.setup_video_saving( base_dir=checkpoint_dir, - video_save_interval=video_save_interval, - post_wrappers=None + video_save_interval=video_save_interval ) with common.make_venv(post_wrappers=post_wrappers) as venv: @@ -305,5 +300,6 @@ def main_console(): train_preference_comparisons_ex.observers.append(observer) train_preference_comparisons_ex.run_commandline() + if __name__ == "__main__": # pragma: no cover main_console() diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index d1e1fd7f1..595f3c0db 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -1,4 +1,5 @@ """Uses RL to train a policy from scratch, saving rollouts and policy. + This can be used: 1. To train a policy on a ground-truth reward function, as a source of synthetic "expert" demonstrations to train IRL or imitation learning @@ -16,7 +17,6 @@ from stable_baselines3.common import callbacks from stable_baselines3.common.vec_env import VecNormalize -import imitation.util.video_wrapper as video_wrapper from imitation.data import rollout, types, wrappers from imitation.policies import serialize from imitation.rewards.reward_wrapper import RewardVecEnvWrapper @@ -43,11 +43,13 @@ def train_rl( agent_path: Optional[str], ) -> Mapping[str, float]: """Trains an expert policy from scratch and saves the rollouts and policy. + Checkpoints: At applicable training steps `step` (where step is either an integer or "final"): - Policies are saved to `{log_dir}/policies/{step}/`. - Rollouts are saved to `{log_dir}/rollouts/{step}.npz`. + Args: total_timesteps: Number of training timesteps in `model.learn()`. normalize_reward: Applies normalization and clipping to the reward function by @@ -84,6 +86,7 @@ def train_rl( policy_save_final: If True, then save the policy right after training is finished. agent_path: Path to load warm-started agent. + Returns: The return value of `rollout_stats()` using the final policy. 
""" @@ -99,7 +102,7 @@ def train_rl( post_wrappers = common.setup_video_saving( base_dir=log_dir, video_save_interval=video_save_interval, - post_wrappers=post_wrappers + post_wrappers=post_wrappers, ) with common.make_venv(post_wrappers=post_wrappers) as venv: @@ -172,4 +175,4 @@ def main_console(): if __name__ == "__main__": # pragma: no cover - main_console() \ No newline at end of file + main_console() diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 12b33bf58..d8e3f2168 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -1,8 +1,9 @@ """Wrapper to record rendered video frames from an environment.""" import pathlib -from typing import Optional, Callable import gym + +from typing import Optional, Callable from gym.wrappers.monitoring import video_recorder @@ -65,8 +66,9 @@ def _reset_video_recorder(self) -> None: self.video_recorder.close() self.video_recorder = None - if self.video_recorder is None and \ - (self.should_record or self.step_count % self.video_save_interval == 0): + on_interval_boundary = self.step_count % self.video_save_interval == 0 + time_to_record = self.should_record or on_interval_boundary + if self.video_recorder is None and time_to_record: # No video recorder -- start a new one. self.video_recorder = video_recorder.VideoRecorder( env=self.env, @@ -85,7 +87,7 @@ def step(self, action): self.step_count += 1 if self.step_count % self.video_save_interval == 0: self.should_record = True - if self.video_recorder != None: + if self.video_recorder is not None: self.video_recorder.capture_frame() return res @@ -99,8 +101,9 @@ def close(self) -> None: def video_wrapper_factory( video_dir: pathlib.Path, video_save_interval: int, - **kwargs) -> Callable: - def f(env: gym.Env, i: int) -> VideoWrapper: + **kwargs + ) -> Callable: + def f(env: gym.Env, i: int) -> VideoWrapper: """ Returns a wrapper around a gym environment records a video if and only if i is 0 @@ -108,11 +111,15 @@ def f(env: gym.Env, i: int) -> VideoWrapper: env: the environment to be wrapped around i: the index of the environment. This is to make the video wrapper compatible with vectorized environments. Only environments with i=0 actually attach the VideoWrapper + + Returns: + A video wrapper around the original environment if the index is 0. + Otherwise, the original environment is just returned. """ return VideoWrapper( env, directory=video_dir, video_save_interval=video_save_interval, - **kwargs + **kwargs, ) if i == 0 else env return f diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index c8ab005fa..d9e8cea65 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -1,4 +1,5 @@ """Smoke tests for CLI programs in `imitation.scripts.*`. + Every test in this file should use `parallel=False` to turn off multiprocessing because codecov might interact poorly with multiprocessing. The 'fast' named_config for each experiment implicitly sets parallel=False. @@ -73,6 +74,7 @@ @pytest.fixture(autouse=True) def sacred_capture_use_sys(): """Set Sacred capture mode to "sys" because default "fd" option leads to error. + See https://github.com/IDSIA/sacred/issues/289. Yields: @@ -648,6 +650,7 @@ def test_preference_comparisons_transfer_learning( named_configs_dict: Mapping[str, List[str]], ) -> None: """Transfer learning smoke test. + Saves a preference comparisons ensemble reward, then loads it for transfer learning. 
     Args:
@@ -951,22 +954,25 @@ def test_convert_trajs(tmpdir: str):
     for t_pkl, t_npz in zip(from_pkl, from_npz):
         assert t_pkl == t_npz
 
+
 # Change the following if the file structure of checkpoints changed.
 VIDEO_FILE_PATH = "video.{:06}.mp4".format(0)
 VIDEO_PATH_DICT = dict(
     rl=lambda d: d / "videos",
     adversarial=lambda d: d / "checkpoints" / "videos",
     pc=lambda d: d / "checkpoints" / "videos",
-    bc=lambda d: d / "videos"
+    bc=lambda d: d / "videos",
 )
 
+
 def _check_video_exists(log_dir, algo):
     video_dir = VIDEO_PATH_DICT[algo](log_dir)
     video_file = video_dir / VIDEO_FILE_PATH
     assert video_dir.exists()
     assert video_file.exists()
 
-@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg takes a long time to install")
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg install takes a while")
 def test_train_rl_video_saving(tmpdir):
     """Smoke test for imitation.scripts.train_rl."""
     config_updates = dict(
@@ -981,7 +987,8 @@ def test_train_rl_video_saving(tmpdir):
     assert run.status == "COMPLETED"
     _check_video_exists(run.config["common"]["log_dir"], "rl")
 
-@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg takes a long time to install")
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg install takes a while")
 def test_train_adversarial_video_saving(tmpdir):
     """Smoke test for imitation.scripts.train_adversarial."""
     named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"]
@@ -998,7 +1005,8 @@ def test_train_adversarial_video_saving(tmpdir):
     assert run.status == "COMPLETED"
     _check_video_exists(run.config["common"]["log_dir"], "adversarial")
 
-@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg takes a long time to install")
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg install takes a while")
 def test_train_preference_comparisons_video_saving(tmpdir):
     config_updates = dict(
         common=dict(log_root=tmpdir),
@@ -1011,7 +1019,8 @@ def test_train_preference_comparisons_video_saving(tmpdir):
     assert run.status == "COMPLETED"
     _check_video_exists(run.config["common"]["log_dir"], "pc")
 
-@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg takes a long time to install")
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg install takes a while")
 def test_train_bc_video_saving(tmpdir):
     config_updates = dict(
         common=dict(log_root=tmpdir),

From db231035561a417aec9be3261eace27fbd556092 Mon Sep 17 00:00:00 2001
From: samuelarnesen
Date: Sun, 30 Oct 2022 19:16:02 -0400
Subject: [PATCH 06/11] fix lint and mytype issues

---
 src/imitation/scripts/common/common.py        |  4 +-
 src/imitation/scripts/eval_policy.py          |  6 +--
 .../scripts/train_preference_comparisons.py   |  3 +-
 src/imitation/util/video_wrapper.py           | 39 +++++++++++--------
 tests/scripts/test_scripts.py                 |  2 +-
 5 files changed, 29 insertions(+), 25 deletions(-)

diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py
index d6add6762..9a4cfcdac 100644
--- a/src/imitation/scripts/common/common.py
+++ b/src/imitation/scripts/common/common.py
@@ -3,19 +3,19 @@
 import contextlib
 import logging
 import pathlib
-from typing import Any, Generator, Mapping, Sequence, Tuple, Union, Optional, Callable
+from typing import Any, Callable, Generator, Mapping, Optional, Sequence, Tuple, Union
 
 import gym
 import numpy as np
 import sacred
 from stable_baselines3.common import vec_env
 
+import imitation.util.video_wrapper as video_wrapper
 from imitation.data import types
 from imitation.scripts.common import wb
 from imitation.util import logger as imit_logger
 from imitation.util import sacred as sacred_util
 from imitation.util import util
-import imitation.util.video_wrapper as video_wrapper
 
 common_ingredient = sacred.Ingredient("common", ingredients=[wb.wandb_ingredient])
 logger = logging.getLogger(__name__)
diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py
index 8fe317dea..bc7158589 100644
--- a/src/imitation/scripts/eval_policy.py
+++ b/src/imitation/scripts/eval_policy.py
@@ -46,10 +46,8 @@ def f(env: gym.Env, i: int) -> video_wrapper.VideoWrapper:
         """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`."""
         directory = log_dir / "videos" / str(i)
         return video_wrapper.VideoWrapper(
-            env,
-            single_video=False,
-            directory=directory,
-            **kwargs)
+            env, single_video=False, directory=directory, **kwargs
+        )
 
     return f
 
diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py
index d5a754ccd..697d645e3 100644
--- a/src/imitation/scripts/train_preference_comparisons.py
+++ b/src/imitation/scripts/train_preference_comparisons.py
@@ -22,6 +22,7 @@
     train_preference_comparisons_ex,
 )
 
+
 def save_model(
     agent_trainer: preference_comparisons.AgentTrainer,
     save_path: pathlib.Path,
@@ -160,7 +161,7 @@ def train_preference_comparisons(
     post_wrappers = common.setup_video_saving(
         base_dir=checkpoint_dir,
-        video_save_interval=video_save_interval
+        video_save_interval=video_save_interval,
     )
 
     with common.make_venv(post_wrappers=post_wrappers) as venv:
diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py
index d8e3f2168..575d35987 100644
--- a/src/imitation/util/video_wrapper.py
+++ b/src/imitation/util/video_wrapper.py
@@ -1,17 +1,17 @@
 """Wrapper to record rendered video frames from an environment."""
 
 import pathlib
-import gym
+from typing import Callable, Optional
 
-from typing import Optional, Callable
-from gym.wrappers.monitoring import video_recorder
+import gym
+from gym.wrappers.monitoring.video_recorder import VideoRecorder
 
 
 class VideoWrapper(gym.Wrapper):
     """Creates videos from wrapped environment by calling render after each timestep."""
 
     episode_id: int
-    video_recorder: Optional[video_recorder.VideoRecorder]
+    video_recorder: Optional[VideoRecorder]
     single_video: bool
     directory: pathlib.Path
     video_save_interval: int
@@ -70,7 +70,7 @@ def _reset_video_recorder(self) -> None:
         time_to_record = self.should_record or on_interval_boundary
         if self.video_recorder is None and time_to_record:
             # No video recorder -- start a new one.
-            self.video_recorder = video_recorder.VideoRecorder(
+            self.video_recorder = VideoRecorder(
                 env=self.env,
                 base_path=str(self.directory / f"video.{self.episode_id:06}"),
                 metadata={"episode_id": self.episode_id},
@@ -101,25 +101,30 @@ def close(self) -> None:
 def video_wrapper_factory(
     video_dir: pathlib.Path,
     video_save_interval: int,
-    **kwargs
-    ) -> Callable:
+    **kwargs,
+) -> Callable:
     def f(env: gym.Env, i: int) -> VideoWrapper:
-        """
-        Returns a wrapper around a gym environment records a video if and only if i is 0
+        """Returns a wrapper around a gym environment that records a video if and only if i is 0.
 
         Args:
             env: the environment to be wrapped around
-            i: the index of the environment. This is to make the video wrapper compatible with
-            vectorized environments. Only environments with i=0 actually attach the VideoWrapper
+            i: the index of the environment. This is to make the video wrapper
+                compatible with vectorized environments.
Only environments with + i=0 actually attach the VideoWrapper Returns: A video wrapper around the original environment if the index is 0. Otherwise, the original environment is just returned. """ - return VideoWrapper( - env, - directory=video_dir, - video_save_interval=video_save_interval, - **kwargs, - ) if i == 0 else env + return ( + VideoWrapper( + env, + directory=video_dir, + video_save_interval=video_save_interval, + **kwargs, + ) + if i == 0 + else env + ) + return f diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index b47c73401..b49495b0f 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -74,7 +74,7 @@ @pytest.fixture(autouse=True) def sacred_capture_use_sys(): """Set Sacred capture mode to "sys" because default "fd" option leads to error. - + See https://github.com/IDSIA/sacred/issues/289. Yields: From 1b1c990ca7e74d768cf78af1a68672b167af2ae4 Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Sun, 30 Oct 2022 20:25:58 -0400 Subject: [PATCH 07/11] fix minor issues --- src/imitation/scripts/eval_policy.py | 2 +- src/imitation/scripts/train_rl.py | 2 +- src/imitation/util/video_wrapper.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index bc7158589..266b6689c 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -46,7 +46,7 @@ def f(env: gym.Env, i: int) -> video_wrapper.VideoWrapper: """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`.""" directory = log_dir / "videos" / str(i) return video_wrapper.VideoWrapper( - env, single_video=False, directory=directory, **kwargs + env, single_video=False, directory=directory, **kwargs, ) return f diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 595f3c0db..d25e6ecb5 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -1,6 +1,6 @@ """Uses RL to train a policy from scratch, saving rollouts and policy. -This can be used: + This can be used: 1. To train a policy on a ground-truth reward function, as a source of synthetic "expert" demonstrations to train IRL or imitation learning algorithms. diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 575d35987..b60daa9db 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -22,7 +22,7 @@ def __init__( self, env: gym.Env, directory: pathlib.Path, - single_video: bool = False, + single_video: bool = True, video_save_interval: int = 1, ): """Builds a VideoWrapper. 
@@ -120,6 +120,7 @@ def f(env: gym.Env, i: int) -> VideoWrapper: VideoWrapper( env, directory=video_dir, + single_video=True, video_save_interval=video_save_interval, **kwargs, ) From 720e245a1b0ddb6dd98162912df335fff025f862 Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Sun, 30 Oct 2022 21:22:49 -0400 Subject: [PATCH 08/11] fix eval policy --- src/imitation/scripts/eval_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 266b6689c..63bbada0a 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -46,7 +46,7 @@ def f(env: gym.Env, i: int) -> video_wrapper.VideoWrapper: """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`.""" directory = log_dir / "videos" / str(i) return video_wrapper.VideoWrapper( - env, single_video=False, directory=directory, **kwargs, + env, directory=directory, **kwargs, ) return f From 63850dd25dd018163a2859365923886a3423a91a Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Sun, 30 Oct 2022 21:31:45 -0400 Subject: [PATCH 09/11] fix doctest --- src/imitation/scripts/train_rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index d25e6ecb5..595f3c0db 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -1,6 +1,6 @@ """Uses RL to train a policy from scratch, saving rollouts and policy. - This can be used: +This can be used: 1. To train a policy on a ground-truth reward function, as a source of synthetic "expert" demonstrations to train IRL or imitation learning algorithms. From 9335d32dcaa0dc4f90c25e08b6a701a00691cc3c Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Sun, 30 Oct 2022 22:47:49 -0400 Subject: [PATCH 10/11] fix doctests --- src/imitation/scripts/eval_policy.py | 4 +++- src/imitation/scripts/train_rl.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 63bbada0a..bfa6e3815 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -46,7 +46,9 @@ def f(env: gym.Env, i: int) -> video_wrapper.VideoWrapper: """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`.""" directory = log_dir / "videos" / str(i) return video_wrapper.VideoWrapper( - env, directory=directory, **kwargs, + env, + directory=directory, + **kwargs, ) return f diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 595f3c0db..b66cd515d 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -45,8 +45,8 @@ def train_rl( """Trains an expert policy from scratch and saves the rollouts and policy. Checkpoints: - At applicable training steps `step` (where step is either an integer or - "final"): + At applicable training steps `step` (where step is either an integer or + "final"): - Policies are saved to `{log_dir}/policies/{step}/`. - Rollouts are saved to `{log_dir}/rollouts/{step}.npz`. 
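The patches above plumb `video_save_interval` through the Sacred scripts, but the wrapper can also be exercised on its own. The following is a minimal usage sketch, not part of any patch in this series: it assumes the gym 0.21-style API (4-tuple `step`, `reset` returning only the observation), a renderable environment such as CartPole-v1, and a working ffmpeg install.

import pathlib

import gym

from imitation.util.video_wrapper import VideoWrapper

# Usage sketch only: record the first episode that starts after every 100th step.
env = VideoWrapper(
    gym.make("CartPole-v1"),
    directory=pathlib.Path("videos"),
    single_video=False,
    video_save_interval=100,
)
obs = env.reset()
for _ in range(500):
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
env.close()

With the default `single_video=True` restored in PATCH 07, the interval is ignored and one continuous video is written for the whole run instead.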
From 771448448018f983b145d9fb6507111e5acecf12 Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Sun, 30 Oct 2022 23:01:22 -0400 Subject: [PATCH 11/11] fix black errors in tests --- tests/scripts/test_scripts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index b49495b0f..6a554f9d6 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -979,7 +979,7 @@ def test_train_rl_video_saving(tmpdir): config_updates = dict( common=dict(log_root=tmpdir), video_save_interval=1, - ) + ) run = train_rl.train_rl_ex.run( named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"], config_updates=config_updates, @@ -997,7 +997,7 @@ def test_train_adversarial_video_saving(tmpdir): common=dict(log_root=tmpdir), demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH), video_save_interval=1, - ) + ) run = train_adversarial.train_adversarial_ex.run( command_name="gail", named_configs=named_configs,