Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save videos during training #597

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open

save videos during training #597

wants to merge 13 commits into from

Conversation

samuelarnesen
Copy link
Contributor

@samuelarnesen samuelarnesen commented Oct 28, 2022

Description

This addresses Issue #523 to automatically save videos during training time. This builds off of the following, earlier PR.

Known Limitations:
(1) Will not necessarily save a video at the end of the training run - it just saves a video at the first episode after each checkpoint.
(2) Saves videos during training episodes (and not in a separate, evaluation environment)

Testing

Added tests to test_scripts.

@AdamGleave
Copy link
Member

A lot of the test failures on Mac seem to be down to ffmpeg not being installed. I think we removed that in #539 as it was making Mac OS build times very high, although an alternative that might be worth revisiting is installing ffmpeg from a static executable (it was compiling it under brew that was slow). Alternatively we can just skip the video tests on Mac, as we do the others.

Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The approach seems good at a high-level and much simpler than the previous one.

Key issue is that unfortunately I think checkpoint_interval does not have unit of timesteps or episodes but rather algorithm iterations, and this is additionally different from algorithm to algorithm. So we either need to do conversion, trigger the video recording in some other way, or specify video recording frequency separately in the config.

I'd also suggest having a way to disable video recording and probably leave it off by default as it introduces overhead and extra binary dependencies.

The other higher-level change is I suspect there's a way to cut down on code duplication in the script by putting a lot of this logic into the common Sacred ingredient. But I've not figured out the details on this so this suggestion may end up being off-base.

Other comments were fairly minor.

@@ -177,4 +177,4 @@ def make_venv(
try:
yield venv
finally:
venv.close()
venv.close()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this change (removing newline) intentional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope -- fixing now

@@ -10,6 +10,8 @@
from imitation.data import rollout
from imitation.policies import base
from imitation.scripts.common import common
from imitation.util import video_wrapper
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused import?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed now!

src/imitation/scripts/train_adversarial.py Show resolved Hide resolved
@@ -21,6 +23,8 @@
from imitation.scripts.config.train_preference_comparisons import (
train_preference_comparisons_ex,
)
import imitation.util.video_wrapper as video_wrapper
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate import?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch -- fixed

Saves a preference comparisons ensemble reward, then loads it for transfer learning.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to remove this?


def _check_video_exists(log_dir, algo):
video_dir = VIDEO_PATH_DICT[algo](log_dir)
assert os.path.exists(video_dir)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If video_dir is a Pathlib.path I think just video_dir.exists() works

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just updated!

def _check_video_exists(log_dir, algo):
video_dir = VIDEO_PATH_DICT[algo](log_dir)
assert os.path.exists(video_dir)
assert VIDEO_FILE_PATH in os.listdir(video_dir)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'd guess there's a pathlib version of this too but not sure.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be updated now

assert os.path.exists(video_dir)
assert VIDEO_FILE_PATH in os.listdir(video_dir)

def test_train_rl_video_saving(tmpdir):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a lot of duplication here with test_train_rl_main, perhaps we can combine them somehow? This comment also applies to some extent to test_train_adversarial_* and test_train_preference_comparisons_* below.

@@ -113,6 +113,11 @@ def test_wandb_output_format():
{"_step": 0, "foo": 42, "fizz": 12},
{"_step": 3, "fizz": 21},
]

with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing invalid input error handling is good but it's a bit odd we're not also testing that it does the right thing with a valid input?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing the basic saving features already exists in a previous test (since manual video saving is already supported) -- I just added this test since it was in the original PR and figured it couldn't hurt.

Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes, I think this is almost there.

Left some comments but most of them are pretty minor stylistic things.

One high-level issue I noticed: I think the video_save_interval only has any effect if single_video is False. Otherwise, I think the video just keeps recording once you start it (though I might be wrong here -- I've not actually tested the bug exists). Unfortunately it's True by default! If I'm right, we should probably do input validation to require that single_video is False whenever video_save_interval != 1, and consider adding a test case to make sure videos are saved at an appropriate interval (could probably just mock this to avoid actually stepping through an environment and saving multiple videos).

There's a lot of lint errors, they seem to mostly be about docstring formatting. Our linter here is a bit obscure, so let me know if you have trouble figuring out what any of them mean. We're expecting docstrings to be in this format: https://www.sphinx-doc.org/en/master/usage/extensions/example_google.html

video_dir = base_dir / "videos"
video_dir.mkdir(parents=True, exist_ok=True)

post_wrappers_copy = [wrapper for wrapper in post_wrappers] \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
post_wrappers_copy = [wrapper for wrapper in post_wrappers] \
post_wrappers_copy = list(post_wrappers) if post_wrappers != None else []

_run,
base_dir: pathlib.Path,
video_save_interval: int,
post_wrappers: Optional[Sequence[Callable[[gym.Env, int], gym.Env]]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sequence[Callable[[gym.Env, int], gym.Env]] is quite long and repeated -- maybe introduce a type alias for it?

@@ -30,6 +30,7 @@ def defaults():
algorithm_specific = {} # algorithm_specific[algorithm] is merged with config

checkpoint_interval = 0 # Num epochs between checkpoints (<0 disables)
video_save_interval = 0 # Number of steps before saving video (<=0 disables)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll note you could put this in configs/common/common.py. However, eval_policy has its own video logic that is not easy to unify (it saves every episode for every environment, not just the first, which makes sense given its purpose) so it's not truly common so seems OK to leave it per-algorithm.

post_wrappers = common_config.setup_video_saving(
base_dir=checkpoint_dir,
video_save_interval=video_save_interval,
post_wrappers=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None is the default so you could probably omit this line?

post_wrappers = common.setup_video_saving(
base_dir=log_dir,
video_save_interval=video_save_interval,
post_wrappers=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None is the default so you could probably omit this line?

@@ -81,7 +82,7 @@ def reset(self):
def step(self, action):
res = self.env.step(action)
self.step_count += 1
if self.step_count % self.cadence == 0:
if self.step_count % self.video_save_interval == 0:
self.should_record == 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line doesn't do anything (equality check) -- was this meant to be an assignment? Also self.should_record has type bool not int.


"""Set Sacred capture mode to "sys" because default "fd" option leads to error.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's odd, most of our functions do not have whitespace before docstring, I am not sure why this would cause a linter error.

Is it this?

tests/scripts/test_scripts.py:1:1: D205 1 blank line required between summary line and description

I think that's complaining there's not a newline following this line, i.e. it's expecting docstrings in format:

A short one sentence summary.

Optionally, a more elaborate description of what this function does.
Some details only for the astute reader.

Args:
    foo: ...

new list with just the video wrapper. If the video_save_interval is <=0,
it will just return the inputted post_wrapper
"""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

"""

if video_save_interval > 0:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

return post_wrappers_copy

return post_wrappers

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally have two lines between functions not four (though I guess more is OK if you want to emphasize separation between them)

@codecov
Copy link

codecov bot commented Oct 31, 2022

Codecov Report

Merging #597 (7714484) into master (d6f8ca2) will increase coverage by 0.04%.
The diff coverage is 98.90%.

@@            Coverage Diff             @@
##           master     #597      +/-   ##
==========================================
+ Coverage   97.49%   97.54%   +0.04%     
==========================================
  Files          85       85              
  Lines        8099     8176      +77     
==========================================
+ Hits         7896     7975      +79     
+ Misses        203      201       -2     
Impacted Files Coverage Δ
src/imitation/scripts/common/rl.py 97.43% <ø> (ø)
src/imitation/scripts/train_adversarial.py 94.87% <83.33%> (+0.20%) ⬆️
src/imitation/scripts/common/common.py 97.82% <100.00%> (+0.29%) ⬆️
src/imitation/scripts/config/train_adversarial.py 71.25% <100.00%> (+0.36%) ⬆️
src/imitation/scripts/config/train_imitation.py 71.66% <100.00%> (+3.87%) ⬆️
...ion/scripts/config/train_preference_comparisons.py 85.52% <100.00%> (+0.19%) ⬆️
src/imitation/scripts/config/train_rl.py 79.22% <100.00%> (+0.27%) ⬆️
src/imitation/scripts/eval_policy.py 100.00% <100.00%> (ø)
src/imitation/scripts/train_imitation.py 95.65% <100.00%> (+0.06%) ⬆️
.../imitation/scripts/train_preference_comparisons.py 97.01% <100.00%> (+0.13%) ⬆️
... and 6 more

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@ernestum ernestum added this to the Release v1.x milestone May 28, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants