enable warm starting adversarial algorithms w/ behavior cloning #602

Closed
wants to merge 1 commit into from
3 changes: 3 additions & 0 deletions src/imitation/scripts/config/train_adversarial.py
@@ -34,6 +34,9 @@ def defaults():

checkpoint_interval = 0 # Num epochs between checkpoints (<0 disables)
agent_path = None # Path to load agent from, optional.
warm_start_with_bc = False # Whether to pre-train the generator with behavior cloning
bc_config = None # Config for the behavior cloning pre-training; only used if warm-starting
device = "auto" # Device to load the BC-pretrained policy onto; only used if warm-starting


@train_adversarial_ex.config
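The new options are driven through Sacred config updates; the test added at the bottom of this diff does exactly that. A minimal sketch of such an invocation (hypothetical rollout path; option names taken from this diff):

import imitation.scripts.train_adversarial as train_adversarial

run = train_adversarial.train_adversarial_ex.run(
    command_name="gail",
    config_updates={
        "demonstrations": dict(rollout_path="path/to/rollouts.npz"),  # hypothetical
        "warm_start_with_bc": True,
        "bc_config": {
            "config_updates": {},  # forwarded to the behavior cloning run
            "named_configs": [],
        },
    },
)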
29 changes: 23 additions & 6 deletions src/imitation/scripts/train_adversarial.py
@@ -3,12 +3,14 @@
import functools
import logging
import pathlib
from typing import Any, Mapping, Optional, Type
from typing import Any, Mapping, Optional, Type, Union

import sacred.commands
import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common import utils

import imitation.scripts.train_imitation as train_imitation
from imitation.algorithms.adversarial import airl as airl_algo
from imitation.algorithms.adversarial import common
from imitation.algorithms.adversarial import gail as gail_algo
@@ -73,6 +75,9 @@ def train_adversarial(
total_timesteps: int,
checkpoint_interval: int,
agent_path: Optional[str],
warm_start_with_bc: bool,
bc_config: Optional[Mapping[str, Any]],
device: Union[str, th.device],
) -> Mapping[str, Mapping[str, float]]:
"""Train an adversarial-network-based imitation learning algorithm.

@@ -98,6 +103,13 @@
provided, then the agent will be initialized using this stored policy
(warm start). If not provided, then the agent will be initialized using
a random policy.
warm_start_with_bc: If True, pre-train the generator policy with behavior
cloning before running the adversarial algorithm.
bc_config: Only used if warm_start_with_bc=True. The settings that govern
the behavior cloning pre-training run; see the behavior cloning
documentation for the individual (optional) parameters.
device: Only used if warm_start_with_bc=True. The device onto which the
pre-trained behavior cloning policy is loaded. Defaults to "auto".

Returns:
A dictionary with two keys. "imit_stats" gives the return value of
@@ -106,11 +118,6 @@
"monitor_return" key). "expert_stats" gives the return value of
`rollout_stats()` on the expert demonstrations.
"""
# This allows specifying total_timesteps and checkpoint_interval in scientific
# notation, which Python interprets as a float.
total_timesteps = int(total_timesteps)
checkpoint_interval = int(checkpoint_interval)

if show_config:
# Running `train_adversarial print_config` will show unmerged config.
# So, support showing merged config from `train_adversarial {airl,gail}`.
@@ -119,6 +126,10 @@
custom_logger, log_dir = logging_ingredient.setup_logging()
expert_trajs = demonstrations.get_expert_trajectories()

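# Run behavior cloning first and remember where its policy checkpoint was saved.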
previous_policy_path = None
if warm_start_with_bc:
previous_policy_path = train_imitation.warm_start_with_bc(bc_config=bc_config)

with environment.make_venv() as venv:
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
@@ -128,6 +139,12 @@

if agent_path is None:
gen_algo = rl.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
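# If BC pre-training ran, replace the fresh policy with the pre-trained one.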
if previous_policy_path is not None:
previous_policy = th.load(
previous_policy_path,
map_location=utils.get_device(device),
)
gen_algo.policy = previous_policy
else:
gen_algo = rl.load_rl_algo_from_path(
agent_path=agent_path,
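The hunk above swaps the BC policy into the freshly constructed generator: assuming the BC run saves the full policy module (which the th.load call above implies), th.load returns an object that can directly replace the algorithm's initial policy. A standalone sketch of that step (hypothetical helper name, mirroring the diff):

import torch as th
from stable_baselines3.common import utils


def load_warm_start_policy(gen_algo, policy_path: str, device: str = "auto"):
    # th.load reconstructs the saved policy module on the requested device.
    policy = th.load(policy_path, map_location=utils.get_device(device))
    # Overwrite the randomly initialized policy with the BC-pretrained one.
    gen_algo.policy = policy
    return gen_algo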
32 changes: 32 additions & 0 deletions src/imitation/scripts/train_imitation.py
@@ -140,6 +140,38 @@ def dagger(
return stats


def warm_start_with_bc(bc_config: Optional[Mapping[str, Any]]) -> str:
"""Used if one wants to pre-train a model with behavior cloning.

Args:
bc_config: map of settings for the behavior cloning experiment. It may
contain two keys, "config_updates" and "named_configs", each of which is
passed through to the behavior cloning training run.

Returns:
The path where the pre-trained policy is saved.
"""
bc_config = bc_config if bc_config is not None else {}

config_updates: Dict[str, Any] = bc_config.get("config_updates", {})
named_configs: Sequence[str] = bc_config.get("named_configs", [])

train_imitation_ex.run(
command_name="bc",
named_configs=named_configs,
config_updates=config_updates,
)

_, log_dir = logging_ingredient.setup_logging()
return osp.join(log_dir, "final.th")


def main_console():
observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_dagger"
observer = FileStorageObserver(observer_path)
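A minimal sketch of calling this helper directly (hypothetical values; the two recognized keys are exactly the ones documented above):

from imitation.scripts import train_imitation

policy_path = train_imitation.warm_start_with_bc(
    bc_config={
        "named_configs": ["cartpole"],  # hypothetical environment config
        "config_updates": {"demonstrations": dict(rollout_path="path/to/rollouts.npz")},
    },
)
# policy_path points at <log_dir>/final.th and can be loaded with th.load().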
30 changes: 30 additions & 0 deletions tests/scripts/test_scripts.py
@@ -598,6 +598,36 @@ def test_train_adversarial_warmstart(tmpdir, command):
_check_train_ex_result(run_warmstart.result)


@pytest.mark.parametrize("command", ("airl", "gail"))
def test_train_adversarial_warmstart_with_bc(tmpdir, command):
"""Test of warmstarts w/ behavior cloning before adversarial training."""
bc_config_updates = dict(
logging=dict(log_root=tmpdir),
expert=dict(policy_type="ppo-huggingface"),
demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
)
bc_named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["imitation"]

named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["adversarial"]

config_updates = {
"logging": dict(log_root=tmpdir),
"demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
"warm_start_with_bc": True,
"bc_config": {
"config_updates": bc_config_updates,
"named_configs": bc_named_configs,
},
}
run = train_adversarial.train_adversarial_ex.run(
command_name=command,
named_configs=named_configs,
config_updates=config_updates,
)
assert run.status == "COMPLETED"
_check_train_ex_result(run.result)


@pytest.mark.parametrize("command", ("airl", "gail"))
def test_train_adversarial_sac(tmpdir, command):
"""Smoke test for imitation.scripts.train_adversarial."""