add video saving and uploading support to train_* scripts #524

Closed · wants to merge 10 commits
src/imitation/algorithms/preference_comparisons.py (5 additions, 3 deletions)
@@ -1397,6 +1397,7 @@ def __init__(
# for keeping track of the global iteration, in case train() is called
# multiple times
self._iteration = 0
self.trajectory_generator_num_steps = 0

self.model = reward_model

@@ -1442,7 +1443,7 @@ def train(
self,
total_timesteps: int,
total_comparisons: int,
callback: Optional[Callable[[int], None]] = None,
callback: Optional[Callable[[int, int], None]] = None,
Review comment (Member): Probably should add in the docstring what the callback type signature represents.

) -> Mapping[str, Any]:
"""Train the reward model and the policy if applicable.

@@ -1526,14 +1527,15 @@ def train(
with self.logger.accumulate_means("agent"):
self.logger.log(f"Training agent for {num_steps} timesteps")
self.trajectory_generator.train(steps=num_steps)
self.trajectory_generator_num_steps += num_steps

self.logger.dump(self._iteration)

########################
# Additional Callbacks #
########################
if callback:
callback(self._iteration)
self._iteration += 1
if callback:
callback(self._iteration, self.trajectory_generator_num_steps)

return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy}
src/imitation/scripts/common/common.py (2 additions, 2 deletions)
@@ -3,7 +3,7 @@
import contextlib
import logging
import os
from typing import Any, Mapping, Sequence, Tuple, Union
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import sacred
from stable_baselines3.common import vec_env
Expand Down Expand Up @@ -134,7 +134,7 @@ def make_venv(
env_name: str,
num_vec: int,
parallel: bool,
log_dir: str,
log_dir: Optional[str],
max_episode_steps: int,
env_make_kwargs: Mapping[str, Any],
**kwargs,
src/imitation/scripts/common/train.py (27 additions, 1 deletion)
@@ -1,14 +1,16 @@
"""Common configuration elements for training imitation algorithms."""

import logging
from typing import Any, Mapping, Union
from typing import Any, Mapping, Optional, Union

import sacred
import stable_baselines3.common.logger as sb_logger
from stable_baselines3.common import base_class, policies, torch_layers, vec_env

import imitation.util.networks
from imitation.data import rollout
from imitation.policies import base
from imitation.util import video_wrapper

train_ingredient = sacred.Ingredient("train")
logger = logging.getLogger(__name__)
@@ -23,6 +25,10 @@ def config():
# Evaluation
n_episodes_eval = 50 # Num of episodes for final mean ground truth return

# Visualization
videos = False # save video files
video_kwargs = {} # arguments to VideoWrapper

locals() # quieten flake8


@@ -111,6 +117,26 @@ def eval_policy(
return rollout.rollout_stats(trajs)


@train_ingredient.capture
def save_video(
Review comment (Member): When you call this function, it self-documents as if the video were always saved, but the flag indicating whether this should happen is magically injected through a decorator. I don't have an immediately better alternative, but perhaps a more explanatory function name could help.

videos: bool,
video_kwargs: Mapping[str, Any],
output_dir: str,
policy: policies.BasePolicy,
eval_venv: vec_env.VecEnv,
logger: Optional[sb_logger.Logger] = None,
) -> None:
"""Save video of imitation policy evaluation."""
if videos:
video_wrapper.record_and_save_video(
output_dir=output_dir,
policy=policy,
eval_venv=eval_venv,
video_kwargs=video_kwargs,
logger=logger,
)


@train_ingredient.capture
def suppress_sacred_error(policy_kwargs: Mapping[str, Any]):
"""No-op so Sacred recognizes `policy_kwargs` is used (in `rl` and elsewhere)."""
src/imitation/scripts/config/train_preference_comparisons.py (2 additions, 1 deletion)
@@ -28,7 +28,7 @@ def train_defaults():
fragment_length = 100 # timesteps per fragment used for comparisons
total_timesteps = int(1e6) # total number of environment timesteps
total_comparisons = 5000 # total number of comparisons to elicit
num_iterations = 5 # Arbitrary, should be tuned for the task
num_iterations = 50 # Arbitrary, should be tuned for the task
Review comment (Member): Apologies if this has been discussed, but why are you doing this?

comparison_queue_size = None
# factor by which to oversample transitions before creating fragments
transition_oversampling = 1
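
For rough intuition on the `num_iterations` change above (an assumption about how the trainer splits its budget, not something stated in this diff): the agent-training timesteps are divided roughly evenly across iterations, so with the defaults here each iteration trains the agent for about `total_timesteps / num_iterations` steps.

```python
total_timesteps, num_iterations = int(1e6), 50  # defaults above
print(total_timesteps // num_iterations)  # ~20000 agent timesteps per iteration
```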
@@ -39,6 +39,7 @@ def train_defaults():
preference_model_kwargs = {}
reward_trainer_kwargs = {
"epochs": 3,
"weight_decay": 0.0,
Review comment (Member): I'll have to remember changing this as I have a PR that replaces weight decay with a general regularization API (#481). @AdamGleave what do you think, should we merge my PR or this one first?

Review comment (Member): Probably best to merge your PR first, though it really depends which one is ready earlier.

Review comment (Member): #481 is ready and passing all the tests AFAIK.

Review comment (Contributor, PR author): Thanks for proposing #481; it seems to be the feature we want here. I'll make changes accordingly.

}
save_preferences = False # save preference dataset at the end?
agent_path = None # path to a (partially) trained agent to load at the beginning
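
For context on the new `weight_decay` key (an assumption about how it is consumed, not something shown in this diff): a reward trainer would typically forward it to its torch optimizer, e.g.:

```python
import torch as th
import torch.nn as nn

reward_net = nn.Linear(4, 1)  # stand-in for the real reward network
optimizer = th.optim.AdamW(reward_net.parameters(), lr=1e-3, weight_decay=0.0)
```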
src/imitation/scripts/train_adversarial.py (28 additions, 14 deletions)
@@ -9,6 +9,7 @@
import sacred.commands
import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common import vec_env

from imitation.algorithms.adversarial import airl as airl_algo
from imitation.algorithms.adversarial import common
@@ -22,16 +23,26 @@
logger = logging.getLogger("imitation.scripts.train_adversarial")


def save(trainer, save_path):
def save_checkpoint(
trainer: common.AdversarialTrainer,
log_dir: str,
eval_venv: vec_env.VecEnv,
round_str: str,
) -> None:
"""Save discriminator and generator."""
save_path = os.path.join(log_dir, "checkpoints", round_str)
Review comment (Member): I have a PR for replacing os.path with pathlib in most places, but might as well keep it consistent for now until that's merged.

# We implement this here and not in Trainer since we do not want to actually
# serialize the whole Trainer (including e.g. expert demonstrations).
os.makedirs(save_path, exist_ok=True)
th.save(trainer.reward_train, os.path.join(save_path, "reward_train.pt"))
th.save(trainer.reward_test, os.path.join(save_path, "reward_test.pt"))
serialize.save_stable_model(
os.path.join(save_path, "gen_policy"),
trainer.gen_algo,
policy_path = os.path.join(save_path, "gen_policy")
serialize.save_stable_model(policy_path, trainer.gen_algo)
train.save_video(
output_dir=policy_path,
policy=trainer.gen_algo.policy,
eval_venv=eval_venv,
logger=trainer.logger,
)


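As a side note to the pathlib comment above (illustrative only, with hypothetical values): the same checkpoint path construction with pathlib would look roughly like this.

```python
from pathlib import Path

log_dir, round_str = "output/train_adversarial", "00003"  # hypothetical values
save_path = Path(log_dir) / "checkpoints" / round_str
save_path.mkdir(parents=True, exist_ok=True)  # replaces os.makedirs(save_path, exist_ok=True)
```
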
@@ -67,7 +78,6 @@ def dummy_config():
@train_adversarial_ex.capture
def train_adversarial(
_run,
_seed: int,
show_config: bool,
algo_cls: Type[common.AdversarialTrainer],
algorithm_kwargs: Mapping[str, Any],
Expand All @@ -84,7 +94,6 @@ def train_adversarial(
- Generator policies are saved to `f"{log_dir}/checkpoints/{step}/gen_policy/"`.

Args:
_seed: Random seed.
show_config: Print the merged config before starting training. This is
analogous to the print_config command, but will show config after
rather than before merging `algorithm_specific` arguments.
@@ -117,6 +126,7 @@ def train_adversarial(
expert_trajs = demonstrations.get_expert_trajectories()

with common_config.make_venv() as venv:

reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
@@ -150,16 +160,20 @@
**algorithm_kwargs,
)

def callback(round_num):
if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
save(trainer, os.path.join(log_dir, "checkpoints", f"{round_num:05d}"))
with common_config.make_venv(num_vec=1, log_dir=None) as eval_venv:

trainer.train(total_timesteps, callback)
imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)
def callback(round_num):
if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
round_str = f"{round_num:05d}"
save_checkpoint(trainer, log_dir, eval_venv, round_str=round_str)

trainer.train(total_timesteps, callback)

# Save final artifacts.
if checkpoint_interval >= 0:
save(trainer, os.path.join(log_dir, "checkpoints", "final"))
# Save final artifacts.
if checkpoint_interval >= 0:
save_checkpoint(trainer, log_dir, eval_venv, round_str="final")

imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)

return {
"imit_stats": imit_stats,
src/imitation/scripts/train_imitation.py (8 additions, 1 deletion)
@@ -63,7 +63,6 @@ def make_policy(

@train_imitation_ex.capture
def train_imitation(
_run,
bc_kwargs: Mapping[str, Any],
bc_train_kwargs: Mapping[str, Any],
dagger: Mapping[str, Any],
@@ -132,6 +131,14 @@ def train_imitation(

imit_stats = train.eval_policy(imit_policy, venv)

with common.make_venv(num_vec=1, log_dir=None) as eval_venv:
train.save_video(
output_dir=log_dir,
policy=imit_policy,
eval_venv=eval_venv,
logger=custom_logger,
)

return {
"imit_stats": imit_stats,
"expert_stats": rollout.rollout_stats(
src/imitation/scripts/train_preference_comparisons.py (46 additions, 39 deletions)
@@ -6,11 +6,12 @@

import functools
import os
import os.path as osp
from typing import Any, Mapping, Optional, Type, Union

import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common import type_aliases
from stable_baselines3.common import type_aliases, vec_env

from imitation.algorithms import preference_comparisons
from imitation.data import types
@@ -23,31 +24,33 @@
)


def save_model(
agent_trainer: preference_comparisons.AgentTrainer,
save_path: str,
):
"""Save the model as model.pkl."""
serialize.save_stable_model(
output_dir=os.path.join(save_path, "policy"),
model=agent_trainer.algorithm,
)


def save_checkpoint(
trainer: preference_comparisons.PreferenceComparisons,
save_path: str,
log_dir: str,
allow_save_policy: Optional[bool],
):
eval_venv: vec_env.VecEnv,
round_str: str,
) -> None:
"""Save reward model and optionally policy."""
save_path = osp.join(log_dir, "checkpoints", round_str)
os.makedirs(save_path, exist_ok=True)
th.save(trainer.model, os.path.join(save_path, "reward_net.pt"))
th.save(trainer.model, osp.join(save_path, "reward_net.pt"))
if allow_save_policy:
# Note: We should only save the model as model.pkl if `trajectory_generator`
# contains one. Specifically we check if the `trajectory_generator` contains an
# `algorithm` attribute.
assert hasattr(trainer.trajectory_generator, "algorithm")
save_model(trainer.trajectory_generator, save_path)
policy_dir = osp.join(save_path, "policy")
serialize.save_stable_model(
output_dir=policy_dir,
model=trainer.trajectory_generator.algorithm,
)
train.save_video(
output_dir=policy_dir,
policy=trainer.trajectory_generator.algorithm.policy,
eval_venv=eval_venv,
logger=trainer.logger,
)
else:
trainer.logger.warn(
"trainer.trajectory_generator doesn't contain a policy to save.",
@@ -242,46 +245,50 @@ def train_preference_comparisons(
query_schedule=query_schedule,
)

def save_callback(iteration_num):
if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0:
# Create an eval_venv for policy evaluation and maybe visualization.
with common.make_venv(num_vec=1, log_dir=None) as eval_venv:

def save_callback(iter_num, traj_gen_num_steps):
if checkpoint_interval > 0 and iter_num % checkpoint_interval == 0:
round_str = f"iter_{iter_num:04d}_step_{traj_gen_num_steps:08d}"
save_checkpoint(
trainer=main_trainer,
log_dir=log_dir,
allow_save_policy=bool(trajectory_path is None),
eval_venv=eval_venv,
round_str=round_str,
)

results = main_trainer.train(
total_timesteps,
total_comparisons,
callback=save_callback,
)

# Save final artifacts.
if checkpoint_interval >= 0:
save_checkpoint(
trainer=main_trainer,
save_path=os.path.join(
log_dir,
"checkpoints",
f"{iteration_num:04d}",
),
log_dir=log_dir,
allow_save_policy=bool(trajectory_path is None),
eval_venv=eval_venv,
round_str="final",
)

results = main_trainer.train(
total_timesteps,
total_comparisons,
callback=save_callback,
)

# Storing and evaluating policy only useful if we generated trajectory data
if bool(trajectory_path is None):
results = dict(results)
results["rollout"] = train.eval_policy(agent, venv)

if save_preferences:
main_trainer.dataset.save(os.path.join(log_dir, "preferences.pkl"))

# Save final artifacts.
if checkpoint_interval >= 0:
save_checkpoint(
trainer=main_trainer,
save_path=os.path.join(log_dir, "checkpoints", "final"),
allow_save_policy=bool(trajectory_path is None),
)
main_trainer.dataset.save(osp.join(log_dir, "preferences.pkl"))

return results

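A small worked example (illustrative values) of the checkpoint directory name produced by the `round_str` formatting in `save_callback` above:

```python
iter_num, traj_gen_num_steps = 3, 200_000  # hypothetical values
round_str = f"iter_{iter_num:04d}_step_{traj_gen_num_steps:08d}"
print(round_str)  # iter_0003_step_00200000
```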

def main_console():
observer = FileStorageObserver(
os.path.join("output", "sacred", "train_preference_comparisons"),
osp.join("output", "sacred", "train_preference_comparisons"),
)
train_preference_comparisons_ex.observers.append(observer)
train_preference_comparisons_ex.run_commandline()