save videos during training #597

Open
wants to merge 13 commits into master
49 changes: 48 additions & 1 deletion src/imitation/scripts/common/common.py
@@ -3,8 +3,9 @@
import contextlib
import logging
import pathlib
from typing import Any, Generator, Mapping, Sequence, Tuple, Union
from typing import Any, Generator, Mapping, Sequence, Tuple, Union, Optional, Callable

import gym
import numpy as np
import sacred
from stable_baselines3.common import vec_env
@@ -14,6 +15,7 @@
from imitation.util import logger as imit_logger
from imitation.util import sacred as sacred_util
from imitation.util import util
import imitation.util.video_wrapper as video_wrapper

common_ingredient = sacred.Ingredient("common", ingredients=[wb.wandb_ingredient])
logger = logging.getLogger(__name__)
@@ -132,6 +134,51 @@ def setup_logging(
)
return custom_logger, log_dir

@common_ingredient.capture
def setup_video_saving(
_run,
base_dir: pathlib.Path,
video_save_interval: int,
post_wrappers: Optional[Sequence[Callable[[gym.Env, int], gym.Env]]] = None,
Member: Sequence[Callable[[gym.Env, int], gym.Env]] is quite long and repeated -- maybe introduce a type alias for it?

) -> Optional[Sequence[Callable[[gym.Env, int], gym.Env]]]:
"""Adds video saving to the existing post wrappers.

Args:
base_dir: the videos will be saved to a videos subdirectory under
this base directory
video_save_interval: the video wrapper will save a video of the next
episode that begins after every Nth step. So if
video_save_interval=100 and each episode has 30 steps, it will record
the 4th episode (the first to start after step_count=100) and then the 7th
episode (first to start after step_count=200).
post_wrappers: If specified, iteratively wraps each environment with each
of the wrappers specified in the sequence. The argument should be a
Callable accepting two arguments, the Env to be wrapped and the
environment index, and returning the wrapped Env.

Returns:
A new post_wrappers list with the video-saving wrapper appended to the
existing list. If the existing post_wrappers is None, a new list containing
just the video wrapper is returned. If video_save_interval is <= 0, the
input post_wrappers is returned unchanged.
"""

if video_save_interval > 0:
video_dir = base_dir / "videos"
video_dir.mkdir(parents=True, exist_ok=True)

post_wrappers_copy = [wrapper for wrapper in post_wrappers] \
if post_wrappers != None else []

Member: Suggested change -- post_wrappers_copy = list(post_wrappers) if post_wrappers != None else []

post_wrappers_copy.append(
video_wrapper.video_wrapper_factory(video_dir, video_save_interval)
)

return post_wrappers_copy

return post_wrappers

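The type alias suggested in the review comment above could be reused by both the parameter and return annotations of setup_video_saving; a minimal sketch (the name PostWrapper is a hypothetical choice, not something defined in the PR):

import gym
from typing import Callable

# A post wrapper receives the environment and its index and returns the wrapped env.
PostWrapper = Callable[[gym.Env, int], gym.Env]
# The signature above would then read:
#   post_wrappers: Optional[Sequence[PostWrapper]] = None,
# and return Optional[Sequence[PostWrapper]].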
Member: Normally have two lines between functions not four (though I guess more is OK if you want to emphasize separation between them)



@contextlib.contextmanager
@common_ingredient.capture
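For context on how the wrapper produced by video_wrapper_factory is consumed: per the docstring above, make_venv applies each post wrapper as a callable taking (env, env_index) and returning the wrapped env. A hypothetical illustration of that contract (not the library's actual VideoWrapper implementation):

import gym

def tag_wrapper_factory(tag: str):
    """Hypothetical factory returning a post wrapper, mirroring the contract
    that video_wrapper_factory(video_dir, video_save_interval) must satisfy."""

    class TaggedEnv(gym.Wrapper):
        def __init__(self, env: gym.Env, env_index: int):
            super().__init__(env)
            self.tag = tag
            self.env_index = env_index

    def wrap(env: gym.Env, env_index: int) -> gym.Env:
        # Called once per sub-environment when the vectorized env is built.
        return TaggedEnv(env, env_index)

    return wrap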
1 change: 1 addition & 0 deletions src/imitation/scripts/common/rl.py
@@ -154,6 +154,7 @@ def make_rl_algo(
)
else:
raise TypeError(f"Unsupported RL algorithm '{rl_cls}'")

rl_algo = rl_cls(
policy=train["policy_cls"],
# Note(yawen): Copy `policy_kwargs` as SB3 may mutate the config we pass.
1 change: 1 addition & 0 deletions src/imitation/scripts/config/train_adversarial.py
@@ -30,6 +30,7 @@ def defaults():
algorithm_specific = {} # algorithm_specific[algorithm] is merged with config

checkpoint_interval = 0 # Num epochs between checkpoints (<0 disables)
video_save_interval = 0 # Number of steps before saving video (<=0 disables)
Member: I'll note you could put this in configs/common/common.py. However, eval_policy has its own video logic that is not easy to unify (it saves every episode for every environment, not just the first, which makes sense given its purpose), so it's not truly common and seems OK to leave per-algorithm.

agent_path = None # Path to load agent from, optional.


1 change: 1 addition & 0 deletions src/imitation/scripts/config/train_imitation.py
@@ -38,6 +38,7 @@ def config():
total_timesteps=1e5,
)
agent_path = None # Path to load agent from, optional.
video_save_interval = 0  # Number of steps before saving video (<=0 disables)


@train_imitation_ex.named_config
@@ -58,6 +58,7 @@ def train_defaults():
allow_variable_horizon = False

checkpoint_interval = 0 # Num epochs between saving (<0 disables, =0 final only)
video_save_interval = 0 # Number of steps before saving video (<=0 disables)
query_schedule = "hyperbolic"


1 change: 1 addition & 0 deletions src/imitation/scripts/config/train_rl.py
@@ -30,6 +30,7 @@ def train_rl_defaults():
rollout_save_n_episodes = None # Num episodes saved per file, optional.

policy_save_interval = 10000 # Num timesteps between saves (<=0 disables)
video_save_interval = 0 # Number of steps before saving video (<=0 disables)
policy_save_final = True # If True, save after training is finished.

agent_path = None # Path to load agent from, optional.
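With video_save_interval exposed as a Sacred config value, a run that records videos can be launched via a config update. A hedged sketch (the import path and experiment name train_rl_ex are assumptions inferred from the sibling scripts in this PR, which define train_imitation_ex and train_preference_comparisons_ex):

# Importing the script module registers the Sacred main function.
from imitation.scripts.train_rl import train_rl_ex  # assumed path

# Save a video roughly every 10,000 environment steps during RL training.
train_rl_ex.run(config_updates={"video_save_interval": 10_000})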
21 changes: 18 additions & 3 deletions src/imitation/scripts/train_adversarial.py
@@ -9,6 +9,7 @@
import torch as th
from sacred.observers import FileStorageObserver

import imitation.util.video_wrapper as video_wrapper
from imitation.algorithms.adversarial import airl as airl_algo
from imitation.algorithms.adversarial import common
from imitation.algorithms.adversarial import gail as gail_algo
@@ -71,6 +72,7 @@ def train_adversarial(
algorithm_kwargs: Mapping[str, Any],
total_timesteps: int,
checkpoint_interval: int,
video_save_interval: int,
agent_path: Optional[str],
) -> Mapping[str, Mapping[str, float]]:
"""Train an adversarial-network-based imitation learning algorithm.
@@ -93,6 +95,10 @@
`checkpoint_interval` rounds and after training is complete. If 0,
then only save weights after training is complete. If <0, then don't
save weights at all.
video_save_interval: The number of steps to take before saving a video.
After that step count is reached, the step count is reset and the next
episode will be recorded in full. Zero or negative values mean no
video is saved.
agent_path: Path to a directory containing a pre-trained agent. If
provided, then the agent will be initialized using this stored policy
(warm start). If not provided, then the agent will be initialized using
@@ -111,9 +117,18 @@
sacred.commands.print_config(_run)

custom_logger, log_dir = common_config.setup_logging()
checkpoint_dir = log_dir / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

expert_trajs = demonstrations.get_expert_trajectories()

with common_config.make_venv() as venv:
post_wrappers = common_config.setup_video_saving(
base_dir=checkpoint_dir,
video_save_interval=video_save_interval,
post_wrappers=None
Member: None is the default so you could probably omit this line?

)

with common_config.make_venv(post_wrappers=post_wrappers) as venv:
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
@@ -149,14 +164,14 @@

def callback(round_num: int, /) -> None:
if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
save(trainer, log_dir / "checkpoints" / f"{round_num:05d}")
save(trainer, checkpoint_dir / f"{round_num:05d}")

trainer.train(total_timesteps, callback)
imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)

# Save final artifacts.
if checkpoint_interval >= 0:
save(trainer, log_dir / "checkpoints" / "final")
save(trainer, checkpoint_dir / "final")

return {
"imit_stats": imit_stats,
13 changes: 12 additions & 1 deletion src/imitation/scripts/train_imitation.py
@@ -71,6 +71,7 @@ def train_imitation(
dagger: Mapping[str, Any],
use_dagger: bool,
agent_path: Optional[str],
video_save_interval: int,
) -> Mapping[str, Mapping[str, float]]:
"""Runs DAgger (if `use_dagger`) or BC (otherwise) training.

@@ -82,14 +83,24 @@
agent_path: Path to serialized policy. If provided, then load the
policy from this path. Otherwise, make a new policy.
Specify only if policy_cls and policy_kwargs are not specified.
video_save_interval: The number of steps to take before saving a video.
After that step count is reached, the step count is reset and the next
episode will be recorded in full. Zero or negative values mean no
video is saved.

Returns:
Statistics for rollouts from the trained policy and demonstration data.
"""
rng = common.make_rng()
custom_logger, log_dir = common.setup_logging()

with common.make_venv() as venv:
post_wrappers = common.setup_video_saving(
base_dir=log_dir,
video_save_interval=video_save_interval,
post_wrappers=None
Member: None is the default so you could probably omit this line?

)

with common.make_venv(post_wrappers=post_wrappers) as venv:
imit_policy = make_policy(venv, agent_path=agent_path)

expert_trajs: Optional[Sequence[types.Trajectory]] = None
25 changes: 21 additions & 4 deletions src/imitation/scripts/train_preference_comparisons.py
@@ -12,6 +12,7 @@
from sacred.observers import FileStorageObserver
from stable_baselines3.common import type_aliases

import gym
from imitation.algorithms import preference_comparisons
from imitation.data import types
from imitation.policies import serialize
@@ -21,6 +22,8 @@
from imitation.scripts.config.train_preference_comparisons import (
train_preference_comparisons_ex,
)
import imitation.util.video_wrapper as video_wrapper
Member: duplicate import?
Contributor (author): Good catch -- fixed




def save_model(
@@ -81,6 +84,7 @@ def train_preference_comparisons(
fragmenter_kwargs: Mapping[str, Any],
allow_variable_horizon: bool,
checkpoint_interval: int,
video_save_interval: int,
query_schedule: Union[str, type_aliases.Schedule],
) -> Mapping[str, Any]:
"""Train a reward model using preference comparisons.
@@ -136,6 +140,10 @@
trajectory_generator contains a policy) every `checkpoint_interval`
iterations and after training is complete. If 0, then only save weights
after training is complete. If <0, then don't save weights at all.
video_save_interval: The number of steps to take before saving a video.
After that step count is reached, the step count is reset and the next
episode will be recorded in full. Zero or negative values mean no
video is saved.
query_schedule: one of ("constant", "hyperbolic", "inverse_quadratic").
A function indicating how the total number of preference queries should
be allocated to each iteration. "hyperbolic" and "inverse_quadratic"
@@ -149,14 +157,24 @@
ValueError: Inconsistency between config and deserialized policy normalization.
"""
custom_logger, log_dir = common.setup_logging()
checkpoint_dir = log_dir / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

rng = common.make_rng()

with common.make_venv() as venv:
post_wrappers = common.setup_video_saving(
base_dir=checkpoint_dir,
video_save_interval=video_save_interval,
post_wrappers=None
Member: None is the default so you could probably omit this line?

)

with common.make_venv(post_wrappers=post_wrappers) as venv:
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
update_stats=False,
)

if agent_path is None:
agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
else:
@@ -250,7 +268,7 @@ def save_callback(iteration_num):
if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0:
save_checkpoint(
trainer=main_trainer,
save_path=log_dir / "checkpoints" / f"{iteration_num:04d}",
save_path=checkpoint_dir / f"{iteration_num:04d}",
allow_save_policy=bool(trajectory_path is None),
)

@@ -272,7 +290,7 @@
if checkpoint_interval >= 0:
save_checkpoint(
trainer=main_trainer,
save_path=log_dir / "checkpoints" / "final",
save_path=checkpoint_dir / "final",
allow_save_policy=bool(trajectory_path is None),
)

@@ -287,6 +305,5 @@ def main_console():
train_preference_comparisons_ex.observers.append(observer)
train_preference_comparisons_ex.run_commandline()


if __name__ == "__main__": # pragma: no cover
main_console()
20 changes: 14 additions & 6 deletions src/imitation/scripts/train_rl.py
@@ -1,5 +1,4 @@
"""Uses RL to train a policy from scratch, saving rollouts and policy.

This can be used:
1. To train a policy on a ground-truth reward function, as a source of
synthetic "expert" demonstrations to train IRL or imitation learning
@@ -17,6 +16,7 @@
from stable_baselines3.common import callbacks
from stable_baselines3.common.vec_env import VecNormalize

import imitation.util.video_wrapper as video_wrapper
from imitation.data import rollout, types, wrappers
from imitation.policies import serialize
from imitation.rewards.reward_wrapper import RewardVecEnvWrapper
@@ -38,18 +38,16 @@ def train_rl(
rollout_save_n_timesteps: Optional[int],
rollout_save_n_episodes: Optional[int],
policy_save_interval: int,
video_save_interval: int,
policy_save_final: bool,
agent_path: Optional[str],
) -> Mapping[str, float]:
"""Trains an expert policy from scratch and saves the rollouts and policy.

Checkpoints:
At applicable training steps `step` (where step is either an integer or
"final"):

- Policies are saved to `{log_dir}/policies/{step}/`.
- Rollouts are saved to `{log_dir}/rollouts/{step}.npz`.

Member: I think we want to keep this whitespace between sections to be compliant with Google docstring style which we've adopted in this project: https://google.github.io/styleguide/pyguide.html#doc-function-raises
Removing whitespace before the list above seems fine though.
Contributor (author): Fixed now

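For reference, the Google-style layout the reviewer points to keeps a blank line between docstring sections; an illustrative, abbreviated fragment (not the PR's final text):

def train_rl(total_timesteps: int) -> None:
    """Trains an expert policy from scratch and saves the rollouts and policy.

    Checkpoints:
        At applicable training steps `step`:
        - Policies are saved to `{log_dir}/policies/{step}/`.
        - Rollouts are saved to `{log_dir}/rollouts/{step}.npz`.

    Args:
        total_timesteps: Number of training timesteps in `model.learn()`.
    """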
Args:
total_timesteps: Number of training timesteps in `model.learn()`.
normalize_reward: Applies normalization and clipping to the reward function by
@@ -79,10 +77,13 @@
policy_save_interval: The number of training updates between in between
intermediate rollout saves. If the argument is nonpositive, then
don't save intermediate updates.
video_save_interval: The number of steps to take before saving a video.
After that step count is reached, the step count is reset and the next
episode will be recorded in full. Zero or negative values mean no
video is saved.
policy_save_final: If True, then save the policy right after training is
finished.
agent_path: Path to load warm-started agent.

Returns:
The return value of `rollout_stats()` using the final policy.
"""
@@ -94,6 +95,13 @@
policy_dir.mkdir(parents=True, exist_ok=True)

post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]

post_wrappers = common.setup_video_saving(
base_dir=log_dir,
video_save_interval=video_save_interval,
post_wrappers=post_wrappers
)

with common.make_venv(post_wrappers=post_wrappers) as venv:
callback_objs = []
if reward_type is not None:
@@ -164,4 +172,4 @@ def main_console():


if __name__ == "__main__": # pragma: no cover
main_console()
main_console()
5 changes: 4 additions & 1 deletion src/imitation/util/logger.py
@@ -377,7 +377,10 @@ def write(
if excluded is not None and "wandb" in excluded:
continue

self.wandb_module.log({key: value}, step=step)
if key != "video":
self.wandb_module.log({key: value}, step=step)
else:
self.wandb_module.log({"video": self.wandb_module.Video(value)})
self.wandb_module.log({}, commit=True)

def close(self) -> None:
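The logger change above routes values logged under the "video" key through wandb's media API rather than as scalars. A minimal, self-contained sketch of what wandb.Video accepts, independent of imitation (the file path and project name are hypothetical; logging raw frames as mp4 additionally requires moviepy to be installed):

import numpy as np
import wandb

run = wandb.init(project="video-logging-demo", mode="offline")

# Log a video file produced by an environment wrapper...
run.log({"video": wandb.Video("videos/episode-000.mp4", format="mp4")})

# ...or log raw frames as a (time, channel, height, width) uint8 array.
frames = np.random.randint(0, 255, size=(16, 3, 64, 64), dtype=np.uint8)
run.log({"video": wandb.Video(frames, fps=8, format="mp4")})

run.finish()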