-
Notifications
You must be signed in to change notification settings - Fork 244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
save videos during training #597
base: master
Are you sure you want to change the base?
Changes from all commits
c41345b
a4211ff
7f5b803
5e6ab25
9751248
b79d863
9074f3c
db23103
1b1c990
720e245
63850dd
9335d32
7714484
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,12 +3,14 @@ | |
import contextlib | ||
import logging | ||
import pathlib | ||
from typing import Any, Generator, Mapping, Sequence, Tuple, Union | ||
from typing import Any, Callable, Generator, Mapping, Optional, Sequence, Tuple, Union | ||
|
||
import gym | ||
import numpy as np | ||
import sacred | ||
from stable_baselines3.common import vec_env | ||
|
||
import imitation.util.video_wrapper as video_wrapper | ||
from imitation.data import types | ||
from imitation.scripts.common import wb | ||
from imitation.util import logger as imit_logger | ||
|
@@ -133,6 +135,48 @@ def setup_logging( | |
return custom_logger, log_dir | ||
|
||
|
||
@common_ingredient.capture | ||
def setup_video_saving( | ||
_run, | ||
base_dir: pathlib.Path, | ||
video_save_interval: int, | ||
post_wrappers: Optional[Sequence[Callable[[gym.Env, int], gym.Env]]] = None, | ||
) -> Optional[Sequence[Callable[[gym.Env, int], gym.Env]]]: | ||
"""Adds video saving to the existing post wrappers. | ||
|
||
Args: | ||
base_dir: the videos will be saved to a videos subdirectory under | ||
this base directory | ||
video_save_interval: the video wrapper will save a video of the next | ||
episode that begins after every Nth step. So if | ||
video_save_interval=100 and each episode has 30 steps, it will record | ||
the 4th episode (first to start after step_count=100) and then the 7th | ||
episode (first to start after step_count=200). | ||
post_wrappers: If specified, iteratively wraps each environment with each | ||
of the wrappers specified in the sequence. The argument should be a | ||
Callable accepting two arguments, the Env to be wrapped and the | ||
environment index, and returning the wrapped Env. | ||
|
||
Returns: | ||
A new post_wrapper list with the video saving wrapper appended to the | ||
existing list. If the existing post wrapper was null, it will create a | ||
new list with just the video wrapper. If the video_save_interval is <=0, | ||
it will just return the input post_wrappers | ||
""" | ||
if video_save_interval > 0: | ||
video_dir = base_dir / "videos" | ||
video_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
post_wrappers_copy = list(post_wrappers) if post_wrappers is not None else [] | ||
post_wrappers_copy.append( | ||
video_wrapper.video_wrapper_factory(video_dir, video_save_interval), | ||
) | ||
|
||
return post_wrappers_copy | ||
|
||
return post_wrappers | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Normally have two lines between functions not four (though I guess more is OK if you want to emphasize separation between them) |
||
|
||
@contextlib.contextmanager | ||
@common_ingredient.capture | ||
def make_venv( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ def defaults(): | |
algorithm_specific = {} # algorithm_specific[algorithm] is merged with config | ||
|
||
checkpoint_interval = 0 # Num epochs between checkpoints (<0 disables) | ||
video_save_interval = 0 # Number of steps before saving video (<=0 disables) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll note you could put this in |
||
agent_path = None # Path to load agent from, optional. | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,15 +38,15 @@ def train_rl( | |
rollout_save_n_timesteps: Optional[int], | ||
rollout_save_n_episodes: Optional[int], | ||
policy_save_interval: int, | ||
video_save_interval: int, | ||
policy_save_final: bool, | ||
agent_path: Optional[str], | ||
) -> Mapping[str, float]: | ||
"""Trains an expert policy from scratch and saves the rollouts and policy. | ||
|
||
Checkpoints: | ||
At applicable training steps `step` (where step is either an integer or | ||
"final"): | ||
|
||
At applicable training steps `step` (where step is either an integer or | ||
"final"): | ||
- Policies are saved to `{log_dir}/policies/{step}/`. | ||
- Rollouts are saved to `{log_dir}/rollouts/{step}.npz`. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we want to keep this whitespace between sections to be compliant with Google docstring style which we've adopted in this project: https://google.github.io/styleguide/pyguide.html#doc-function-raises Removing whitespace before the list above seems fine though. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fixed now |
||
|
@@ -79,6 +79,10 @@ def train_rl( | |
policy_save_interval: The number of training updates between in between | ||
intermediate rollout saves. If the argument is nonpositive, then | ||
don't save intermediate updates. | ||
video_save_interval: The number of steps to take before saving a video. | ||
After that step count is reached, the step count is reset and the next | ||
episode will be recorded in full. Zero or negative values mean no | ||
video is saved. | ||
policy_save_final: If True, then save the policy right after training is | ||
finished. | ||
agent_path: Path to load warm-started agent. | ||
|
@@ -94,6 +98,13 @@ def train_rl( | |
policy_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] | ||
|
||
post_wrappers = common.setup_video_saving( | ||
base_dir=log_dir, | ||
video_save_interval=video_save_interval, | ||
post_wrappers=post_wrappers, | ||
) | ||
|
||
with common.make_venv(post_wrappers=post_wrappers) as venv: | ||
callback_objs = [] | ||
if reward_type is not None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sequence[Callable[[gym.Env, int], gym.Env]]
is quite long and repeated -- maybe introduce a type alias for it?