add video saving and uploading support to train_* scripts (#524)
```diff
@@ -1,14 +1,16 @@
 """Common configuration elements for training imitation algorithms."""
 
 import logging
-from typing import Any, Mapping, Union
+from typing import Any, Mapping, Optional, Union
 
 import sacred
+import stable_baselines3.common.logger as sb_logger
 from stable_baselines3.common import base_class, policies, torch_layers, vec_env
 
 import imitation.util.networks
 from imitation.data import rollout
 from imitation.policies import base
+from imitation.util import video_wrapper
 
 train_ingredient = sacred.Ingredient("train")
 logger = logging.getLogger(__name__)
```
```diff
@@ -23,6 +25,10 @@ def config():
     # Evaluation
     n_episodes_eval = 50  # Num of episodes for final mean ground truth return
 
+    # Visualization
+    videos = False  # save video files
+    video_kwargs = {}  # arguments to VideoWrapper
+
     locals()  # quieten flake8
```
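Since `videos` and `video_kwargs` live in the `train` Sacred ingredient, they can be overridden per run. A minimal usage sketch, assuming the experiment object and a `gail` command are exposed as in the rest of the codebase (the command name and the nested-dict config update are assumptions, not verified against this PR):

```python
# Hypothetical sketch: enable video saving via Sacred config updates.
from imitation.scripts.train_adversarial import train_adversarial_ex

train_adversarial_ex.run(
    command_name="gail",  # assumption: GAIL is exposed as a named command
    config_updates={"train": {"videos": True, "video_kwargs": {}}},
)
```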
```diff
@@ -111,6 +117,26 @@ def eval_policy(
     return rollout.rollout_stats(trajs)
 
 
+@train_ingredient.capture
+def save_video(
+    videos: bool,
+    video_kwargs: Mapping[str, Any],
+    output_dir: str,
+    policy: policies.BasePolicy,
+    eval_venv: vec_env.VecEnv,
+    logger: Optional[sb_logger.Logger] = None,
+) -> None:
+    """Save video of imitation policy evaluation."""
+    if videos:
+        video_wrapper.record_and_save_video(
+            output_dir=output_dir,
+            policy=policy,
+            eval_venv=eval_venv,
+            video_kwargs=video_kwargs,
+            logger=logger,
+        )
+
+
 @train_ingredient.capture
 def suppress_sacred_error(policy_kwargs: Mapping[str, Any]):
     """No-op so Sacred recognizes `policy_kwargs` is used (in `rl` and elsewhere)."""
```

> **Review comment** (on `save_video`): When you call this function, it self-documents as if the video were always saved, but the flag indicating whether this should happen is magically injected through a decorator. I don't have an immediately better alternative, but perhaps a more explanatory function name could help.
---
```diff
@@ -28,7 +28,7 @@ def train_defaults():
     fragment_length = 100  # timesteps per fragment used for comparisons
     total_timesteps = int(1e6)  # total number of environment timesteps
     total_comparisons = 5000  # total number of comparisons to elicit
-    num_iterations = 5  # Arbitrary, should be tuned for the task
+    num_iterations = 50  # Arbitrary, should be tuned for the task
     comparison_queue_size = None
     # factor by which to oversample transitions before creating fragments
     transition_oversampling = 1
```

> **Review comment** (on `num_iterations`): Apologies if this has been discussed, but why are you doing this?
```diff
@@ -39,6 +39,7 @@ def train_defaults():
     preference_model_kwargs = {}
     reward_trainer_kwargs = {
         "epochs": 3,
+        "weight_decay": 0.0,
     }
     save_preferences = False  # save preference dataset at the end?
     agent_path = None  # path to a (partially) trained agent to load at the beginning
```

> **Review comment** (on `"weight_decay"`): I'll have to remember to change this, as I have a PR that replaces weight decay with a general regularization API (#481). @AdamGleave, what do you think: should we merge my PR or this one first?
>
> **Reply:** Probably best to merge your PR first, though it really depends which one is ready earlier.
>
> **Reply:** #481 is ready and passing all the tests AFAIK.
---
```diff
@@ -9,6 +9,7 @@
 import sacred.commands
 import torch as th
 from sacred.observers import FileStorageObserver
+from stable_baselines3.common import vec_env
 
 from imitation.algorithms.adversarial import airl as airl_algo
 from imitation.algorithms.adversarial import common
```
```diff
@@ -22,16 +23,26 @@
 logger = logging.getLogger("imitation.scripts.train_adversarial")
 
 
-def save(trainer, save_path):
+def save_checkpoint(
+    trainer: common.AdversarialTrainer,
+    log_dir: str,
+    eval_venv: vec_env.VecEnv,
+    round_str: str,
+) -> None:
     """Save discriminator and generator."""
+    save_path = os.path.join(log_dir, "checkpoints", round_str)
     # We implement this here and not in Trainer since we do not want to actually
     # serialize the whole Trainer (including e.g. expert demonstrations).
     os.makedirs(save_path, exist_ok=True)
     th.save(trainer.reward_train, os.path.join(save_path, "reward_train.pt"))
     th.save(trainer.reward_test, os.path.join(save_path, "reward_test.pt"))
-    serialize.save_stable_model(
-        os.path.join(save_path, "gen_policy"),
-        trainer.gen_algo,
-    )
+    policy_path = os.path.join(save_path, "gen_policy")
+    serialize.save_stable_model(policy_path, trainer.gen_algo)
+    train.save_video(
+        output_dir=policy_path,
+        policy=trainer.gen_algo.policy,
+        eval_venv=eval_venv,
+        logger=trainer.logger,
+    )
```

> **Review comment** (on `save_path = os.path.join(...)`): I have a PR replacing os.path with pathlib in most places, but we might as well keep this consistent for now until that's merged.
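For reference, a minimal sketch of the pathlib form that comment alludes to; this is not part of this PR, and it assumes `serialize.save_stable_model` accepts a string path:

```python
# Hypothetical pathlib version of the path handling in save_checkpoint
# (illustration only; not part of this PR).
from pathlib import Path

save_path = Path(log_dir) / "checkpoints" / round_str
save_path.mkdir(parents=True, exist_ok=True)  # replaces os.makedirs(..., exist_ok=True)
th.save(trainer.reward_train, save_path / "reward_train.pt")  # torch.save accepts os.PathLike
th.save(trainer.reward_test, save_path / "reward_test.pt")
policy_path = save_path / "gen_policy"
serialize.save_stable_model(str(policy_path), trainer.gen_algo)
```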
```diff
@@ -67,7 +78,6 @@ def dummy_config():
 @train_adversarial_ex.capture
 def train_adversarial(
     _run,
-    _seed: int,
     show_config: bool,
     algo_cls: Type[common.AdversarialTrainer],
     algorithm_kwargs: Mapping[str, Any],
```
```diff
@@ -84,7 +94,6 @@ def train_adversarial(
     - Generator policies are saved to `f"{log_dir}/checkpoints/{step}/gen_policy/"`.
 
     Args:
-        _seed: Random seed.
         show_config: Print the merged config before starting training. This is
             analogous to the print_config command, but will show config after
             rather than before merging `algorithm_specific` arguments.
```
```diff
@@ -117,6 +126,7 @@ def train_adversarial(
     expert_trajs = demonstrations.get_expert_trajectories()
 
     with common_config.make_venv() as venv:
+
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
```
```diff
@@ -150,16 +160,20 @@ def train_adversarial(
         **algorithm_kwargs,
     )
 
-    def callback(round_num):
-        if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
-            save(trainer, os.path.join(log_dir, "checkpoints", f"{round_num:05d}"))
+    with common_config.make_venv(num_vec=1, log_dir=None) as eval_venv:
 
-    trainer.train(total_timesteps, callback)
-    imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)
+        def callback(round_num):
+            if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
+                round_str = f"{round_num:05d}"
+                save_checkpoint(trainer, log_dir, eval_venv, round_str=round_str)
 
-    # Save final artifacts.
-    if checkpoint_interval >= 0:
-        save(trainer, os.path.join(log_dir, "checkpoints", "final"))
+        trainer.train(total_timesteps, callback)
+
+        # Save final artifacts.
+        if checkpoint_interval >= 0:
+            save_checkpoint(trainer, log_dir, eval_venv, round_str="final")
+
+    imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)
 
     return {
         "imit_stats": imit_stats,
```
> **Review comment** (on `callback`): Probably should add in the docstring what the callback type signature represents.
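A minimal sketch of what such a docstring could say, assuming `trainer.train` invokes the callback with the index of the just-completed round; the wording is illustrative, not from the PR:

```python
# Illustrative only: document the callback's signature and when it runs.
def callback(round_num: int) -> None:
    """Checkpoint hook invoked by `trainer.train` after each training round.

    Args:
        round_num: Index of the round that just finished; a checkpoint is
            saved whenever `checkpoint_interval` is positive and `round_num`
            is a multiple of it.
    """
    if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
        save_checkpoint(trainer, log_dir, eval_venv, round_str=f"{round_num:05d}")
```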