diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py
index 72d44f2f4..9a4cfcdac 100644
--- a/src/imitation/scripts/common/common.py
+++ b/src/imitation/scripts/common/common.py
@@ -3,12 +3,14 @@
 import contextlib
 import logging
 import pathlib
-from typing import Any, Generator, Mapping, Sequence, Tuple, Union
+from typing import Any, Callable, Generator, Mapping, Optional, Sequence, Tuple, Union
 
+import gym
 import numpy as np
 import sacred
 from stable_baselines3.common import vec_env
 
+import imitation.util.video_wrapper as video_wrapper
 from imitation.data import types
 from imitation.scripts.common import wb
 from imitation.util import logger as imit_logger
@@ -133,6 +135,48 @@ def setup_logging(
     return custom_logger, log_dir
 
 
+@common_ingredient.capture
+def setup_video_saving(
+    _run,
+    base_dir: pathlib.Path,
+    video_save_interval: int,
+    post_wrappers: Optional[Sequence[Callable[[gym.Env, int], gym.Env]]] = None,
+) -> Optional[Sequence[Callable[[gym.Env, int], gym.Env]]]:
+    """Adds video saving to the existing post wrappers.
+
+    Args:
+        base_dir: The videos will be saved to a "videos" subdirectory under
+            this base directory.
+        video_save_interval: The video wrapper will save a video of the next
+            episode that begins after every Nth step. For example, if
+            video_save_interval=100 and each episode has 30 steps, it will
+            record the first episode to begin after step 100 (episode_id 4,
+            counting from 0) and then the first episode to begin after step
+            200 (episode_id 7).
+        post_wrappers: If specified, iteratively wraps each environment with each
+            of the wrappers specified in the sequence. The argument should be a
+            Callable accepting two arguments, the Env to be wrapped and the
+            environment index, and returning the wrapped Env.
+
+    Returns:
+        A new post_wrappers list with the video-saving wrapper appended to the
+        existing list. If the existing post_wrappers is None, a new list
+        containing just the video wrapper is returned. If video_save_interval
+        is <=0, the input post_wrappers is returned unchanged.
+    """
+    if video_save_interval > 0:
+        video_dir = base_dir / "videos"
+        video_dir.mkdir(parents=True, exist_ok=True)
+
+        post_wrappers_copy = list(post_wrappers) if post_wrappers is not None else []
+        post_wrappers_copy.append(
+            video_wrapper.video_wrapper_factory(video_dir, video_save_interval),
+        )
+
+        return post_wrappers_copy
+
+    return post_wrappers
+
+
 @contextlib.contextmanager
 @common_ingredient.capture
 def make_venv(
diff --git a/src/imitation/scripts/common/rl.py b/src/imitation/scripts/common/rl.py
index 2bd3759a2..67ec0d785 100644
--- a/src/imitation/scripts/common/rl.py
+++ b/src/imitation/scripts/common/rl.py
@@ -154,6 +154,7 @@ def make_rl_algo(
         )
     else:
         raise TypeError(f"Unsupported RL algorithm '{rl_cls}'")
+
     rl_algo = rl_cls(
         policy=train["policy_cls"],
         # Note(yawen): Copy `policy_kwargs` as SB3 may mutate the config we pass.
diff --git a/src/imitation/scripts/config/train_adversarial.py b/src/imitation/scripts/config/train_adversarial.py
index 3183ac9f6..3779350eb 100644
--- a/src/imitation/scripts/config/train_adversarial.py
+++ b/src/imitation/scripts/config/train_adversarial.py
@@ -30,6 +30,7 @@ def defaults():
     algorithm_specific = {}  # algorithm_specific[algorithm] is merged with config
 
     checkpoint_interval = 0  # Num epochs between checkpoints (<0 disables)
+    video_save_interval = 0  # Number of steps before saving video (<=0 disables)
     agent_path = None  # Path to load agent from, optional.
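Reviewer note (not part of the patch): the post-wrapper composition that `setup_video_saving` performs can be pictured with the short sketch below. It is only an illustration of the intended behaviour (copy the list, append the video wrapper, return the copy, leave the caller's list untouched), assuming this patch is applied. `setup_video_saving` itself is a Sacred-captured ingredient function, so the sketch reproduces its effect directly rather than calling it; the `output/videos` path, the `CartPole-v1` environment, and the `TimeLimit` lambda are illustrative stand-ins.

    import pathlib

    import gym

    from imitation.util import video_wrapper

    # A caller-supplied post-wrapper list (the helper also accepts None).
    post_wrappers = [lambda env, i: gym.wrappers.TimeLimit(env, max_episode_steps=500)]

    # What setup_video_saving does when video_save_interval > 0: copy the list
    # and append the video-saving wrapper, leaving the caller's list unchanged.
    video_dir = pathlib.Path("output/videos")
    video_dir.mkdir(parents=True, exist_ok=True)
    wrappers = list(post_wrappers)
    wrappers.append(video_wrapper.video_wrapper_factory(video_dir, video_save_interval=100))

    # A vec-env builder then applies each wrapper in turn, passing the env index;
    # the factory only attaches a VideoWrapper to the environment with index 0.
    env = gym.make("CartPole-v1")
    for wrap in wrappers:
        env = wrap(env, 0)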
diff --git a/src/imitation/scripts/config/train_imitation.py b/src/imitation/scripts/config/train_imitation.py
index c2466a936..6d8360358 100644
--- a/src/imitation/scripts/config/train_imitation.py
+++ b/src/imitation/scripts/config/train_imitation.py
@@ -38,6 +38,7 @@ def config():
         total_timesteps=1e5,
     )
     agent_path = None  # Path to load agent from, optional.
+    video_save_interval = 0  # Number of steps before saving video (<=0 disables)
 
 
 @train_imitation_ex.named_config
diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py
index ba4e9483c..9f0d554fd 100644
--- a/src/imitation/scripts/config/train_preference_comparisons.py
+++ b/src/imitation/scripts/config/train_preference_comparisons.py
@@ -58,6 +58,7 @@ def train_defaults():
     allow_variable_horizon = False
 
     checkpoint_interval = 0  # Num epochs between saving (<0 disables, =0 final only)
+    video_save_interval = 0  # Number of steps before saving video (<=0 disables)
     query_schedule = "hyperbolic"
diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py
index b9ede3165..b5ed922c0 100644
--- a/src/imitation/scripts/config/train_rl.py
+++ b/src/imitation/scripts/config/train_rl.py
@@ -30,6 +30,7 @@ def train_rl_defaults():
     rollout_save_n_episodes = None  # Num episodes saved per file, optional.
 
     policy_save_interval = 10000  # Num timesteps between saves (<=0 disables)
+    video_save_interval = 0  # Number of steps before saving video (<=0 disables)
     policy_save_final = True  # If True, save after training is finished.
     agent_path = None  # Path to load agent from, optional.
diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py
index dee635271..bfa6e3815 100644
--- a/src/imitation/scripts/eval_policy.py
+++ b/src/imitation/scripts/eval_policy.py
@@ -45,7 +45,11 @@ def video_wrapper_factory(log_dir: pathlib.Path, **kwargs):
     def f(env: gym.Env, i: int) -> video_wrapper.VideoWrapper:
         """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`."""
         directory = log_dir / "videos" / str(i)
-        return video_wrapper.VideoWrapper(env, directory=directory, **kwargs)
+        return video_wrapper.VideoWrapper(
+            env,
+            directory=directory,
+            **kwargs,
+        )
 
     return f
diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py
index b84aec720..5dcb534e2 100644
--- a/src/imitation/scripts/train_adversarial.py
+++ b/src/imitation/scripts/train_adversarial.py
@@ -71,6 +71,7 @@ def train_adversarial(
     algorithm_kwargs: Mapping[str, Any],
     total_timesteps: int,
     checkpoint_interval: int,
+    video_save_interval: int,
     agent_path: Optional[str],
 ) -> Mapping[str, Mapping[str, float]]:
     """Train an adversarial-network-based imitation learning algorithm.
@@ -93,6 +94,10 @@
             `checkpoint_interval` rounds and after training is complete. If 0,
             then only save weights after training is complete. If <0, then don't
             save weights at all.
+        video_save_interval: The number of steps to take before saving a video.
+            After that step count is reached, the step count is reset and the
+            next episode will be recorded in full. A value of zero or less
+            means no video is saved.
         agent_path: Path to a directory containing a pre-trained agent. If
             provided, then the agent will be initialized using this stored policy
            (warm start). If not provided, then the agent will be initialized using
@@ -111,9 +116,17 @@
     sacred.commands.print_config(_run)
 
     custom_logger, log_dir = common_config.setup_logging()
+    checkpoint_dir = log_dir / "checkpoints"
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
     expert_trajs = demonstrations.get_expert_trajectories()
 
-    with common_config.make_venv() as venv:
+    post_wrappers = common_config.setup_video_saving(
+        base_dir=checkpoint_dir,
+        video_save_interval=video_save_interval,
+    )
+
+    with common_config.make_venv(post_wrappers=post_wrappers) as venv:
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
@@ -149,14 +162,14 @@ def callback(round_num: int, /) -> None:
             if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
-                save(trainer, log_dir / "checkpoints" / f"{round_num:05d}")
+                save(trainer, checkpoint_dir / f"{round_num:05d}")
 
         trainer.train(total_timesteps, callback)
         imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)
 
     # Save final artifacts.
     if checkpoint_interval >= 0:
-        save(trainer, log_dir / "checkpoints" / "final")
+        save(trainer, checkpoint_dir / "final")
 
     return {
         "imit_stats": imit_stats,
diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py
index 09393366e..4276b710a 100644
--- a/src/imitation/scripts/train_imitation.py
+++ b/src/imitation/scripts/train_imitation.py
@@ -71,6 +71,7 @@ def train_imitation(
     dagger: Mapping[str, Any],
     use_dagger: bool,
     agent_path: Optional[str],
+    video_save_interval: int,
 ) -> Mapping[str, Mapping[str, float]]:
     """Runs DAgger (if `use_dagger`) or BC (otherwise) training.
 
@@ -82,6 +83,10 @@
         agent_path: Path to serialized policy. If provided, then load the
            policy from this path. Otherwise, make a new policy.
            Specify only if policy_cls and policy_kwargs are not specified.
+        video_save_interval: The number of steps to take before saving a video.
+            After that step count is reached, the step count is reset and the
+            next episode will be recorded in full. A value of zero or less
+            means no video is saved.
 
     Returns:
         Statistics for rollouts from the trained policy and demonstration data.
@@ -89,7 +94,12 @@
     rng = common.make_rng()
     custom_logger, log_dir = common.setup_logging()
 
-    with common.make_venv() as venv:
+    post_wrappers = common.setup_video_saving(
+        base_dir=log_dir,
+        video_save_interval=video_save_interval,
+    )
+
+    with common.make_venv(post_wrappers=post_wrappers) as venv:
         imit_policy = make_policy(venv, agent_path=agent_path)
 
         expert_trajs: Optional[Sequence[types.Trajectory]] = None
diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py
index 331a4797a..697d645e3 100644
--- a/src/imitation/scripts/train_preference_comparisons.py
+++ b/src/imitation/scripts/train_preference_comparisons.py
@@ -81,6 +81,7 @@ def train_preference_comparisons(
     fragmenter_kwargs: Mapping[str, Any],
     allow_variable_horizon: bool,
     checkpoint_interval: int,
+    video_save_interval: int,
     query_schedule: Union[str, type_aliases.Schedule],
 ) -> Mapping[str, Any]:
     """Train a reward model using preference comparisons.
@@ -136,6 +137,10 @@
            trajectory_generator contains a policy) every `checkpoint_interval`
            iterations and after training is complete. If 0, then only save weights
            after training is complete. If <0, then don't save weights at all.
+        video_save_interval: The number of steps to take before saving a video.
+            After that step count is reached, the step count is reset and the
+            next episode will be recorded in full. A value of zero or less
+            means no video is saved.
         query_schedule: one of ("constant", "hyperbolic", "inverse_quadratic").
             A function indicating how the total number of preference queries should
             be allocated to each iteration. "hyperbolic" and "inverse_quadratic"
@@ -149,14 +154,23 @@
         ValueError: Inconsistency between config and deserialized policy normalization.
     """
     custom_logger, log_dir = common.setup_logging()
+    checkpoint_dir = log_dir / "checkpoints"
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
     rng = common.make_rng()
 
-    with common.make_venv() as venv:
+    post_wrappers = common.setup_video_saving(
+        base_dir=checkpoint_dir,
+        video_save_interval=video_save_interval,
+    )
+
+    with common.make_venv(post_wrappers=post_wrappers) as venv:
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
             update_stats=False,
         )
+
         if agent_path is None:
             agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
         else:
@@ -250,7 +264,7 @@ def save_callback(iteration_num):
         if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0:
             save_checkpoint(
                 trainer=main_trainer,
-                save_path=log_dir / "checkpoints" / f"{iteration_num:04d}",
+                save_path=checkpoint_dir / f"{iteration_num:04d}",
                 allow_save_policy=bool(trajectory_path is None),
             )
 
@@ -272,7 +286,7 @@ def save_callback(iteration_num):
     if checkpoint_interval >= 0:
         save_checkpoint(
             trainer=main_trainer,
-            save_path=log_dir / "checkpoints" / "final",
+            save_path=checkpoint_dir / "final",
             allow_save_policy=bool(trajectory_path is None),
         )
 
diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py
index fb7959592..b66cd515d 100644
--- a/src/imitation/scripts/train_rl.py
+++ b/src/imitation/scripts/train_rl.py
@@ -38,15 +38,15 @@ def train_rl(
     rollout_save_n_timesteps: Optional[int],
     rollout_save_n_episodes: Optional[int],
     policy_save_interval: int,
+    video_save_interval: int,
     policy_save_final: bool,
     agent_path: Optional[str],
 ) -> Mapping[str, float]:
     """Trains an expert policy from scratch and saves the rollouts and policy.
 
     Checkpoints:
-        At applicable training steps `step` (where step is either an integer or
-            "final"):
-
+      At applicable training steps `step` (where step is either an integer or
+      "final"):
         - Policies are saved to `{log_dir}/policies/{step}/`.
         - Rollouts are saved to `{log_dir}/rollouts/{step}.npz`.
 
@@ -79,6 +79,10 @@
         policy_save_interval: The number of training updates between in between
            intermediate rollout saves. If the argument is nonpositive, then
            don't save intermediate updates.
+        video_save_interval: The number of steps to take before saving a video.
+            After that step count is reached, the step count is reset and the
+            next episode will be recorded in full. A value of zero or less
+            means no video is saved.
         policy_save_final: If True, then save the policy right after training is
            finished.
         agent_path: Path to load warm-started agent.
@@ -94,6 +98,13 @@ def train_rl(
         policy_dir.mkdir(parents=True, exist_ok=True)
 
     post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]
+
+    post_wrappers = common.setup_video_saving(
+        base_dir=log_dir,
+        video_save_interval=video_save_interval,
+        post_wrappers=post_wrappers,
+    )
+
     with common.make_venv(post_wrappers=post_wrappers) as venv:
         callback_objs = []
         if reward_type is not None:
diff --git a/src/imitation/util/logger.py b/src/imitation/util/logger.py
index 70190b1fb..c3ca758ba 100644
--- a/src/imitation/util/logger.py
+++ b/src/imitation/util/logger.py
@@ -377,7 +377,10 @@ def write(
             if excluded is not None and "wandb" in excluded:
                 continue
-            self.wandb_module.log({key: value}, step=step)
+            if key != "video":
+                self.wandb_module.log({key: value}, step=step)
+            else:
+                self.wandb_module.log({"video": self.wandb_module.Video(value)})
 
         self.wandb_module.log({}, commit=True)
 
     def close(self) -> None:
diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py
index a59641aa1..b60daa9db 100644
--- a/src/imitation/util/video_wrapper.py
+++ b/src/imitation/util/video_wrapper.py
@@ -1,25 +1,29 @@
 """Wrapper to record rendered video frames from an environment."""
 
 import pathlib
-from typing import Optional
+from typing import Callable, Optional
 
 import gym
-from gym.wrappers.monitoring import video_recorder
+from gym.wrappers.monitoring.video_recorder import VideoRecorder
 
 
 class VideoWrapper(gym.Wrapper):
     """Creates videos from wrapped environment by calling render after each timestep."""
 
     episode_id: int
-    video_recorder: Optional[video_recorder.VideoRecorder]
+    video_recorder: Optional[VideoRecorder]
     single_video: bool
     directory: pathlib.Path
+    video_save_interval: int
+    should_record: bool
+    step_count: int
 
     def __init__(
         self,
         env: gym.Env,
         directory: pathlib.Path,
         single_video: bool = True,
+        video_save_interval: int = 1,
     ):
         """Builds a VideoWrapper.
 
@@ -31,14 +35,23 @@ def __init__(
                Usually a single video file is what is desired. However, if one is
                searching for an interesting episode (perhaps by looking at the
                metadata), then saving to different files can be useful.
+            video_save_interval: The video wrapper will save a video of the next
+                episode that begins after every Nth step. For example, if
+                video_save_interval=100 and each episode has 30 steps, it will
+                record the first episode to begin after step 100 (episode_id 4,
+                counting from 0) and then the first episode to begin after step
+                200 (episode_id 7). If single_video is True, this value does not
+                apply, as a single video is recorded throughout.
         """
         super().__init__(env)
         self.episode_id = 0
         self.video_recorder = None
         self.single_video = single_video
+        self.video_save_interval = video_save_interval
         self.directory = directory
         self.directory.mkdir(parents=True, exist_ok=True)
+        self.should_record = self.single_video
+        self.step_count = 0
 
     def _reset_video_recorder(self) -> None:
         """Creates a video recorder if one does not already exist.
@@ -53,13 +66,16 @@ def _reset_video_recorder(self) -> None:
             self.video_recorder.close()
             self.video_recorder = None
 
-        if self.video_recorder is None:
+        on_interval_boundary = self.step_count % self.video_save_interval == 0
+        time_to_record = self.should_record or on_interval_boundary
+        if self.video_recorder is None and time_to_record:
             # No video recorder -- start a new one.
-            self.video_recorder = video_recorder.VideoRecorder(
+            self.video_recorder = VideoRecorder(
                 env=self.env,
                 base_path=str(self.directory / f"video.{self.episode_id:06}"),
                 metadata={"episode_id": self.episode_id},
             )
+            self.should_record = self.single_video
 
     def reset(self):
         self._reset_video_recorder()
@@ -68,7 +84,11 @@ def reset(self):
 
     def step(self, action):
         res = self.env.step(action)
-        self.video_recorder.capture_frame()
+        self.step_count += 1
+        if self.step_count % self.video_save_interval == 0:
+            self.should_record = True
+        if self.video_recorder is not None:
+            self.video_recorder.capture_frame()
         return res
 
     def close(self) -> None:
@@ -76,3 +96,36 @@ def close(self) -> None:
             self.video_recorder.close()
             self.video_recorder = None
         super().close()
+
+
+def video_wrapper_factory(
+    video_dir: pathlib.Path,
+    video_save_interval: int,
+    **kwargs,
+) -> Callable:
+    def f(env: gym.Env, i: int) -> VideoWrapper:
+        """Wraps `env` in a video recorder if and only if `i` is 0.
+
+        Args:
+            env: the environment to be wrapped.
+            i: the index of the environment. This is to make the video wrapper
+                compatible with vectorized environments. Only the environment
+                with i=0 actually attaches the VideoWrapper.
+
+        Returns:
+            A video wrapper around the original environment if the index is 0.
+            Otherwise, the original environment is returned unchanged.
+        """
+        return (
+            VideoWrapper(
+                env,
+                directory=video_dir,
+                single_video=True,
+                video_save_interval=video_save_interval,
+                **kwargs,
+            )
+            if i == 0
+            else env
+        )
+
+    return f
diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py
index 20e9a6559..6a554f9d6 100644
--- a/tests/scripts/test_scripts.py
+++ b/tests/scripts/test_scripts.py
@@ -954,3 +954,84 @@ def test_convert_trajs(tmpdir: str):
     assert len(from_pkl) == len(from_npz)
     for t_pkl, t_npz in zip(from_pkl, from_npz):
         assert t_pkl == t_npz
+
+
+# Change the following if the file structure of checkpoints changes.
+VIDEO_FILE_PATH = "video.{:06}.mp4".format(0)
+VIDEO_PATH_DICT = dict(
+    rl=lambda d: d / "videos",
+    adversarial=lambda d: d / "checkpoints" / "videos",
+    pc=lambda d: d / "checkpoints" / "videos",
+    bc=lambda d: d / "videos",
+)
+
+
+def _check_video_exists(log_dir, algo):
+    video_dir = VIDEO_PATH_DICT[algo](log_dir)
+    video_file = video_dir / VIDEO_FILE_PATH
+    assert video_dir.exists()
+    assert video_file.exists()
+
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg install takes a while")
+def test_train_rl_video_saving(tmpdir):
+    """Smoke test for video saving in imitation.scripts.train_rl."""
+    config_updates = dict(
+        common=dict(log_root=tmpdir),
+        video_save_interval=1,
+    )
+    run = train_rl.train_rl_ex.run(
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"],
+        config_updates=config_updates,
+    )
+
+    assert run.status == "COMPLETED"
+    _check_video_exists(run.config["common"]["log_dir"], "rl")
+
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg install takes a while")
+def test_train_adversarial_video_saving(tmpdir):
+    """Smoke test for video saving in imitation.scripts.train_adversarial."""
+    named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"]
+    config_updates = dict(
+        common=dict(log_root=tmpdir),
+        demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH),
+        video_save_interval=1,
+    )
+    run = train_adversarial.train_adversarial_ex.run(
+        command_name="gail",
+        named_configs=named_configs,
+        config_updates=config_updates,
+    )
+    assert run.status == "COMPLETED"
+    _check_video_exists(run.config["common"]["log_dir"], "adversarial")
+
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg install takes a while")
+def test_train_preference_comparisons_video_saving(tmpdir):
+    config_updates = dict(
+        common=dict(log_root=tmpdir),
+        video_save_interval=1,
+    )
+    run = train_preference_comparisons.train_preference_comparisons_ex.run(
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"],
+        config_updates=config_updates,
+    )
+    assert run.status == "COMPLETED"
+    _check_video_exists(run.config["common"]["log_dir"], "pc")
+
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="ffmpeg install takes a while")
+def test_train_bc_video_saving(tmpdir):
+    config_updates = dict(
+        common=dict(log_root=tmpdir),
+        demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+        video_save_interval=1,
+    )
+    run = train_imitation.train_imitation_ex.run(
+        command_name="bc",
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"],
+        config_updates=config_updates,
+    )
+    assert run.status == "COMPLETED"
+    _check_video_exists(run.config["common"]["log_dir"], "bc")
diff --git a/tests/util/test_wb_logger.py b/tests/util/test_wb_logger.py
index f3b1a85a9..7bbfbdb63 100644
--- a/tests/util/test_wb_logger.py
+++ b/tests/util/test_wb_logger.py
@@ -113,6 +113,11 @@ def test_wandb_output_format():
         {"_step": 0, "foo": 42, "fizz": 12},
         {"_step": 3, "fizz": 21},
     ]
+
+    with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"):
+        log_obj.record("video", 42)
+        log_obj.dump(step=4)
+
     log_obj.close()
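Reviewer note (not part of the patch): as a sanity check on the interval semantics documented above, the stand-alone sketch below mirrors the `step_count` / `should_record` bookkeeping of `VideoWrapper` with `single_video=False` and reproduces the episode numbers used in the docstrings. The function name and its parameters are made up for illustration only.

    def recorded_episode_ids(video_save_interval, episode_len, n_episodes):
        """Simulates which episode_ids VideoWrapper records when single_video=False."""
        recorded = []
        step_count = 0
        should_record = False
        for episode_id in range(n_episodes):
            # reset(): a recorder starts when flagged or on an interval boundary.
            if should_record or step_count % video_save_interval == 0:
                recorded.append(episode_id)
                should_record = False
            for _ in range(episode_len):
                step_count += 1
                if step_count % video_save_interval == 0:
                    should_record = True
        return recorded

    # interval=100 with 30-step episodes: episode 0 (step_count=0 is on a
    # boundary), then episode 4 (the first to begin after step 100) and
    # episode 7 (the first to begin after step 200).
    print(recorded_episode_ids(100, 30, 9))  # [0, 4, 7]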