save videos during training #597
base: master
@@ -3,8 +3,9 @@
 import contextlib
 import logging
 import pathlib
-from typing import Any, Generator, Mapping, Sequence, Tuple, Union
+from typing import Any, Generator, Mapping, Sequence, Tuple, Union, Optional, Callable

+import gym
 import numpy as np
 import sacred
 from stable_baselines3.common import vec_env

@@ -14,6 +15,7 @@
 from imitation.util import logger as imit_logger
 from imitation.util import sacred as sacred_util
 from imitation.util import util
+import imitation.util.video_wrapper as video_wrapper

 common_ingredient = sacred.Ingredient("common", ingredients=[wb.wandb_ingredient])
 logger = logging.getLogger(__name__)

@@ -132,6 +134,51 @@ def setup_logging(
     )
     return custom_logger, log_dir


+@common_ingredient.capture
+def setup_video_saving(
+    _run,
+    base_dir: pathlib.Path,
+    video_save_interval: int,
+    post_wrappers: Optional[Sequence[Callable[[gym.Env, int], gym.Env]]] = None,
+) -> Optional[Sequence[Callable[[gym.Env, int], gym.Env]]]:
+    """Adds video saving to the existing post wrappers.
+
+    Args:
+        base_dir: the videos will be saved to a "videos" subdirectory under
+            this base directory.
+        video_save_interval: the video wrapper will save a video of the next
+            episode that begins after every Nth step. So if
+            video_save_interval=100 and each episode has 30 steps, it will
+            record the 4th episode (the first to start after step_count=100)
+            and then the 7th episode (the first to start after step_count=200).
+        post_wrappers: If specified, iteratively wraps each environment with
+            each of the wrappers specified in the sequence. Each element
+            should be a Callable accepting two arguments, the Env to be
+            wrapped and the environment index, and returning the wrapped Env.
+
+    Returns:
+        A new post_wrappers list with the video-saving wrapper appended to the
+        existing list. If the existing list is None, a new list holding just
+        the video wrapper is returned. If video_save_interval <= 0, the input
+        post_wrappers is returned unchanged.
+    """
+    if video_save_interval > 0:
+        video_dir = base_dir / "videos"
+        video_dir.mkdir(parents=True, exist_ok=True)
+
+        post_wrappers_copy = list(post_wrappers) if post_wrappers is not None else []
+        post_wrappers_copy.append(
+            video_wrapper.video_wrapper_factory(video_dir, video_save_interval),
+        )
+        return post_wrappers_copy
+
+    return post_wrappers
+
+
+
+
 @contextlib.contextmanager
 @common_ingredient.capture

Normally have two lines between functions not four (though I guess more is OK if you want to emphasize separation between them)
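The `video_wrapper.video_wrapper_factory` helper that `setup_video_saving` calls is not part of this diff. For orientation, here is a minimal sketch of a factory with the interval semantics the docstring describes; the class name `IntervalVideoWrapper`, the per-environment subdirectories, and the use of gym's `VideoRecorder` are illustrative assumptions, not the PR's actual implementation:

import pathlib
from typing import Callable, Optional

import gym
from gym.wrappers.monitoring import video_recorder


def video_wrapper_factory(
    video_dir: pathlib.Path,
    video_save_interval: int,
) -> Callable[[gym.Env, int], gym.Env]:
    """Returns a post-wrapper that records one full episode every N steps (sketch)."""

    def f(env: gym.Env, i: int) -> gym.Env:
        # One subdirectory per vectorized sub-environment avoids filename clashes.
        return IntervalVideoWrapper(env, video_dir / f"env_{i}", video_save_interval)

    return f


class IntervalVideoWrapper(gym.Wrapper):
    """Records the first episode that starts after every Nth step (hypothetical)."""

    def __init__(self, env: gym.Env, directory: pathlib.Path, interval: int):
        super().__init__(env)
        self.directory = directory
        self.directory.mkdir(parents=True, exist_ok=True)
        self.interval = interval
        self.step_count = 0
        self.episode_id = 0
        self.record_next_episode = False
        self.recorder: Optional[video_recorder.VideoRecorder] = None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        if self.record_next_episode:
            # Record this episode in full, then wait for the next interval.
            self.recorder = video_recorder.VideoRecorder(
                self.env,
                base_path=str(self.directory / f"video_{self.episode_id:03d}"),
            )
            self.episode_id += 1
            self.record_next_episode = False
        return obs

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self.step_count += 1
        if self.step_count >= self.interval:
            self.step_count = 0  # counter resets, as the docstring above describes
            self.record_next_episode = True
        if self.recorder is not None:
            self.recorder.capture_frame()
            if done:
                self.recorder.close()
                self.recorder = None
        return obs, rew, done, info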
@@ -30,6 +30,7 @@ def defaults():
     algorithm_specific = {}  # algorithm_specific[algorithm] is merged with config

     checkpoint_interval = 0  # Num epochs between checkpoints (<0 disables)
+    video_save_interval = 0  # Number of steps before saving video (<=0 disables)

     agent_path = None  # Path to load agent from, optional.

I'll note you could put this in [...]
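Because `video_save_interval` is a Sacred config value, it can also be enabled from the command line without code changes, via Sacred's standard `with key=value` override. A hypothetical invocation (script name, named config, and value are illustrative):

python -m imitation.scripts.train_adversarial gail with video_save_interval=100000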
@@ -9,6 +9,7 @@
 import torch as th
 from sacred.observers import FileStorageObserver

+import imitation.util.video_wrapper as video_wrapper
 from imitation.algorithms.adversarial import airl as airl_algo
 from imitation.algorithms.adversarial import common
 from imitation.algorithms.adversarial import gail as gail_algo

@@ -71,6 +72,7 @@ def train_adversarial(
     algorithm_kwargs: Mapping[str, Any],
     total_timesteps: int,
     checkpoint_interval: int,
+    video_save_interval: int,
     agent_path: Optional[str],
 ) -> Mapping[str, Mapping[str, float]]:
     """Train an adversarial-network-based imitation learning algorithm.

@@ -93,6 +95,10 @@
         `checkpoint_interval` rounds and after training is complete. If 0,
         then only save weights after training is complete. If <0, then don't
         save weights at all.
+        video_save_interval: The number of steps to take before saving a
+            video. After that step count is reached, the step count is reset
+            and the next episode will be recorded in full. Zero or negative
+            values mean no video is saved.
         agent_path: Path to a directory containing a pre-trained agent. If
             provided, then the agent will be initialized using this stored policy
             (warm start). If not provided, then the agent will be initialized using

@@ -111,9 +117,18 @@
     sacred.commands.print_config(_run)

     custom_logger, log_dir = common_config.setup_logging()
+    checkpoint_dir = log_dir / "checkpoints"
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)

     expert_trajs = demonstrations.get_expert_trajectories()

-    with common_config.make_venv() as venv:
+    post_wrappers = common_config.setup_video_saving(
+        base_dir=checkpoint_dir,
+        video_save_interval=video_save_interval,
+        post_wrappers=None,
+    )
+
+    with common_config.make_venv(post_wrappers=post_wrappers) as venv:
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,

None is the default so you could probably omit this line?
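A sketch of the reviewer's suggested simplification, since `post_wrappers=None` just restates the parameter's default (the same applies to the identical calls in train_imitation and train_preference_comparisons below):

post_wrappers = common_config.setup_video_saving(
    base_dir=checkpoint_dir,
    video_save_interval=video_save_interval,
)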
@@ -149,14 +164,14 @@

     def callback(round_num: int, /) -> None:
         if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
-            save(trainer, log_dir / "checkpoints" / f"{round_num:05d}")
+            save(trainer, checkpoint_dir / f"{round_num:05d}")

     trainer.train(total_timesteps, callback)
     imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)

     # Save final artifacts.
     if checkpoint_interval >= 0:
-        save(trainer, log_dir / "checkpoints" / "final")
+        save(trainer, checkpoint_dir / "final")

     return {
         "imit_stats": imit_stats,

@@ -71,6 +71,7 @@ def train_imitation(
     dagger: Mapping[str, Any],
     use_dagger: bool,
     agent_path: Optional[str],
+    video_save_interval: int,
 ) -> Mapping[str, Mapping[str, float]]:
     """Runs DAgger (if `use_dagger`) or BC (otherwise) training.

@@ -82,14 +83,24 @@
         agent_path: Path to serialized policy. If provided, then load the
             policy from this path. Otherwise, make a new policy.
            Specify only if policy_cls and policy_kwargs are not specified.
+        video_save_interval: The number of steps to take before saving a
+            video. After that step count is reached, the step count is reset
+            and the next episode will be recorded in full. Zero or negative
+            values mean no video is saved.

     Returns:
         Statistics for rollouts from the trained policy and demonstration data.
     """
     rng = common.make_rng()
     custom_logger, log_dir = common.setup_logging()

-    with common.make_venv() as venv:
+    post_wrappers = common.setup_video_saving(
+        base_dir=log_dir,
+        video_save_interval=video_save_interval,
+        post_wrappers=None,
+    )
+
+    with common.make_venv(post_wrappers=post_wrappers) as venv:
         imit_policy = make_policy(venv, agent_path=agent_path)

         expert_trajs: Optional[Sequence[types.Trajectory]] = None

None is the default so you could probably omit this line?
@@ -12,6 +12,7 @@
 from sacred.observers import FileStorageObserver
 from stable_baselines3.common import type_aliases

+import gym
 from imitation.algorithms import preference_comparisons
 from imitation.data import types
 from imitation.policies import serialize

@@ -21,6 +22,8 @@
 from imitation.scripts.config.train_preference_comparisons import (
     train_preference_comparisons_ex,
 )
+import imitation.util.video_wrapper as video_wrapper
+


 def save_model(

duplicate import?

Good catch -- fixed
@@ -81,6 +84,7 @@ def train_preference_comparisons(
     fragmenter_kwargs: Mapping[str, Any],
     allow_variable_horizon: bool,
     checkpoint_interval: int,
+    video_save_interval: int,
     query_schedule: Union[str, type_aliases.Schedule],
 ) -> Mapping[str, Any]:
     """Train a reward model using preference comparisons.

@@ -136,6 +140,10 @@
         trajectory_generator contains a policy) every `checkpoint_interval`
         iterations and after training is complete. If 0, then only save weights
         after training is complete. If <0, then don't save weights at all.
+        video_save_interval: The number of steps to take before saving a
+            video. After that step count is reached, the step count is reset
+            and the next episode will be recorded in full. Zero or negative
+            values mean no video is saved.
         query_schedule: one of ("constant", "hyperbolic", "inverse_quadratic").
             A function indicating how the total number of preference queries should
             be allocated to each iteration. "hyperbolic" and "inverse_quadratic"

@@ -149,14 +157,24 @@
         ValueError: Inconsistency between config and deserialized policy normalization.
     """
     custom_logger, log_dir = common.setup_logging()
+    checkpoint_dir = log_dir / "checkpoints"
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)

     rng = common.make_rng()

-    with common.make_venv() as venv:
+    post_wrappers = common.setup_video_saving(
+        base_dir=checkpoint_dir,
+        video_save_interval=video_save_interval,
+        post_wrappers=None,
+    )
+
+    with common.make_venv(post_wrappers=post_wrappers) as venv:
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
             update_stats=False,
         )

         if agent_path is None:
             agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
         else:

None is the default so you could probably omit this line?
@@ -250,7 +268,7 @@ def save_callback(iteration_num):
         if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0:
             save_checkpoint(
                 trainer=main_trainer,
-                save_path=log_dir / "checkpoints" / f"{iteration_num:04d}",
+                save_path=checkpoint_dir / f"{iteration_num:04d}",
                 allow_save_policy=bool(trajectory_path is None),
             )

@@ -272,7 +290,7 @@ def save_callback(iteration_num):
     if checkpoint_interval >= 0:
         save_checkpoint(
             trainer=main_trainer,
-            save_path=log_dir / "checkpoints" / "final",
+            save_path=checkpoint_dir / "final",
             allow_save_policy=bool(trajectory_path is None),
         )

@@ -287,6 +305,5 @@ def main_console():
     train_preference_comparisons_ex.observers.append(observer)
     train_preference_comparisons_ex.run_commandline()

-
 if __name__ == "__main__":  # pragma: no cover
     main_console()

@@ -1,5 +1,4 @@
 """Uses RL to train a policy from scratch, saving rollouts and policy.
-
 This can be used:
 1. To train a policy on a ground-truth reward function, as a source of
    synthetic "expert" demonstrations to train IRL or imitation learning

@@ -17,6 +16,7 @@
 from stable_baselines3.common import callbacks
 from stable_baselines3.common.vec_env import VecNormalize

+import imitation.util.video_wrapper as video_wrapper
 from imitation.data import rollout, types, wrappers
 from imitation.policies import serialize
 from imitation.rewards.reward_wrapper import RewardVecEnvWrapper

@@ -38,18 +38,16 @@ def train_rl(
     rollout_save_n_timesteps: Optional[int],
     rollout_save_n_episodes: Optional[int],
     policy_save_interval: int,
+    video_save_interval: int,
     policy_save_final: bool,
     agent_path: Optional[str],
 ) -> Mapping[str, float]:
     """Trains an expert policy from scratch and saves the rollouts and policy.
-
     Checkpoints:
         At applicable training steps `step` (where step is either an integer or
         "final"):
-
         - Policies are saved to `{log_dir}/policies/{step}/`.
         - Rollouts are saved to `{log_dir}/rollouts/{step}.npz`.
-

     Args:
         total_timesteps: Number of training timesteps in `model.learn()`.
         normalize_reward: Applies normalization and clipping to the reward function by

I think we want to keep this whitespace between sections to be compliant with Google docstring style which we've adopted in this project: https://google.github.io/styleguide/pyguide.html#doc-function-raises
Removing whitespace before the list above seems fine though.

Fixed now
@@ -79,10 +77,13 @@
         policy_save_interval: The number of training updates between
             intermediate rollout saves. If the argument is nonpositive, then
             don't save intermediate updates.
+        video_save_interval: The number of steps to take before saving a
+            video. After that step count is reached, the step count is reset
+            and the next episode will be recorded in full. Zero or negative
+            values mean no video is saved.
         policy_save_final: If True, then save the policy right after training is
             finished.
         agent_path: Path to load warm-started agent.

     Returns:
         The return value of `rollout_stats()` using the final policy.
     """

@@ -94,6 +95,13 @@
     policy_dir.mkdir(parents=True, exist_ok=True)

     post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]

+    post_wrappers = common.setup_video_saving(
+        base_dir=log_dir,
+        video_save_interval=video_save_interval,
+        post_wrappers=post_wrappers,
+    )
+
     with common.make_venv(post_wrappers=post_wrappers) as venv:
         callback_objs = []
         if reward_type is not None:

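Since `setup_video_saving` copies the incoming list and appends the video wrapper (when `video_save_interval > 0`; otherwise the list is returned unchanged), the list passed to `make_venv` here is conceptually:

post_wrappers = [
    lambda env, idx: wrappers.RolloutInfoWrapper(env),
    video_wrapper.video_wrapper_factory(log_dir / "videos", video_save_interval),
]

Because post wrappers are applied in order, the video wrapper ends up outermost, so it records frames from the already-wrapped environment.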
@@ -164,4 +172,4 @@ def main_console():


 if __name__ == "__main__":  # pragma: no cover
-    main_console()
\ No newline at end of file
+    main_console()

Sequence[Callable[[gym.Env, int], gym.Env]] is quite long and repeated -- maybe introduce a type alias for it?
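One possible implementation of the suggested alias; the name PostWrapper and its placement near the imports in common.py are illustrative, not part of the PR:

# A named alias makes the repeated signatures considerably shorter.
PostWrapper = Callable[[gym.Env, int], gym.Env]


def setup_video_saving(
    _run,
    base_dir: pathlib.Path,
    video_save_interval: int,
    post_wrappers: Optional[Sequence[PostWrapper]] = None,
) -> Optional[Sequence[PostWrapper]]:
    ...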