From 24a88b8ee10aa1864363555076bbcc53f1eb0506 Mon Sep 17 00:00:00 2001 From: samuelarnesen Date: Fri, 28 Oct 2022 23:25:43 -0400 Subject: [PATCH] enable warm starting adversarial algorithms w/ behavior cloning --- .../scripts/config/train_adversarial.py | 3 ++ src/imitation/scripts/train_adversarial.py | 29 +++++++++++++---- src/imitation/scripts/train_imitation.py | 32 +++++++++++++++++++ tests/scripts/test_scripts.py | 30 +++++++++++++++++ 4 files changed, 88 insertions(+), 6 deletions(-) diff --git a/src/imitation/scripts/config/train_adversarial.py b/src/imitation/scripts/config/train_adversarial.py index 55e6effec..ddfb73841 100644 --- a/src/imitation/scripts/config/train_adversarial.py +++ b/src/imitation/scripts/config/train_adversarial.py @@ -34,6 +34,9 @@ def defaults(): checkpoint_interval = 0 # Num epochs between checkpoints (<0 disables) agent_path = None # Path to load agent from, optional. + warm_start_with_bc = False # default to not warmstarting with behavior cloning + bc_config = None # No config for behavior cloning if not warm-starting + device = "auto" # Device needed only if warmstarting w/ behavior cloning @train_adversarial_ex.config diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index 26c8d7bcf..4c98b4cf7 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -3,12 +3,14 @@ import functools import logging import pathlib -from typing import Any, Mapping, Optional, Type +from typing import Any, Mapping, Optional, Type, Union import sacred.commands import torch as th from sacred.observers import FileStorageObserver +from stable_baselines3.common import utils +import imitation.scripts.train_imitation as train_imitation from imitation.algorithms.adversarial import airl as airl_algo from imitation.algorithms.adversarial import common from imitation.algorithms.adversarial import gail as gail_algo @@ -73,6 +75,9 @@ def train_adversarial( total_timesteps: int, checkpoint_interval: int, agent_path: Optional[str], + warm_start_with_bc: bool, + bc_config: Optional[Mapping[str, Any]], + device: Union[str, th.device], ) -> Mapping[str, Mapping[str, float]]: """Train an adversarial-network-based imitation learning algorithm. @@ -98,6 +103,13 @@ def train_adversarial( provided, then the agent will be initialized using this stored policy (warm start). If not provided, then the agent will be initialized using a random policy. + warm_start_with_bc: boolean indicates whether one should pre-train using + behavior cloning before using one of the adversarial algorithms + bc_config: Only applies if warm_start_with_bc=True. These are the settings + that govern the pre-training w/ behavior cloning. See the documentation + for behavior cloning for all the (optional) individual parameters. + device: Only needed if warm_start_with_bc is true. This is the device that + the training is running on. Defaults to "auto". Returns: A dictionary with two keys. "imit_stats" gives the return value of @@ -106,11 +118,6 @@ def train_adversarial( "monitor_return" key). "expert_stats" gives the return value of `rollout_stats()` on the expert demonstrations. """ - # This allows to specify total_timesteps and checkpoint_interval in scientific - # notation, which is interpreted as a float by python. - total_timesteps = int(total_timesteps) - checkpoint_interval = int(checkpoint_interval) - if show_config: # Running `train_adversarial print_config` will show unmerged config. 
# So, support showing merged config from `train_adversarial {airl,gail}`. @@ -119,6 +126,10 @@ def train_adversarial( custom_logger, log_dir = logging_ingredient.setup_logging() expert_trajs = demonstrations.get_expert_trajectories() + previous_policy_path = None + if warm_start_with_bc: + previous_policy_path = train_imitation.warm_start_with_bc(bc_config=bc_config) + with environment.make_venv() as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( @@ -128,6 +139,12 @@ def train_adversarial( if agent_path is None: gen_algo = rl.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn) + if previous_policy_path is not None: + previous_policy = th.load( + previous_policy_path, + map_location=utils.get_device(device), + ) + gen_algo.policy = previous_policy else: gen_algo = rl.load_rl_algo_from_path( agent_path=agent_path, diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index e607339b4..02a0c160d 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -140,6 +140,38 @@ def dagger( return stats +def warm_start_with_bc(bc_config: Optional[Mapping[str, Any]]) -> str: + """Used if one wants to pre-train a model with behavior cloning. + + Args: + bc_config: map of the settings to run a behavior cloning experiment. There + should be two keys: "config_updates" and "named_configs" with each + corresponding to the terms one wants to pass through to the behavior + cloning training run. + + Returns: + The path to where the pre-trained model is saved. + """ + bc_config = bc_config if bc_config is not None else {} + + config_updates: Optional[Dict[Any, Any]] = {} + if "config_updates" in bc_config: + config_updates = bc_config["config_updates"] + + named_configs: Sequence[str] = [] + if "named_configs" in bc_config: + named_configs = bc_config["named_configs"] + + train_imitation_ex.run( + command_name="bc", + named_configs=named_configs, + config_updates=config_updates, + ) + + _, log_dir = logging_ingredient.setup_logging() + return osp.join(log_dir, "final.th") + + def main_console(): observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_dagger" observer = FileStorageObserver(observer_path) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index c47335ef7..52e3d57fa 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -598,6 +598,36 @@ def test_train_adversarial_warmstart(tmpdir, command): _check_train_ex_result(run_warmstart.result) +@pytest.mark.parametrize("command", ("airl", "gail")) +def test_train_adversarial_warmstart_with_bc(tmpdir, command): + """Test of warmstarts w/ behavior cloning before adversarial training.""" + bc_config_updates = dict( + logging=dict(log_root=tmpdir), + expert=dict(policy_type="ppo-huggingface"), + demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH), + ) + bc_named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["imitation"] + + named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["adversarial"] + + config_updates = { + "logging": dict(log_root=tmpdir), + "demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH), + "warm_start_with_bc": True, + "bc_config": { + "config_updates": bc_config_updates, + "named_configs": bc_named_configs, + }, + } + run = train_adversarial.train_adversarial_ex.run( + command_name=command, + named_configs=named_configs, + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_train_ex_result(run.result) + + 
@pytest.mark.parametrize("command", ("airl", "gail")) def test_train_adversarial_sac(tmpdir, command): """Smoke test for imitation.scripts.train_adversarial."""
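
A minimal usage sketch of the new warm-start option, modeled directly on the test added in this patch. It launches GAIL through the Sacred experiment API with a behavior-cloning pre-training step. The log directory and rollout path are placeholders, and the fast test-suite configs (ALGO_FAST_CONFIGS) used in the test are omitted here, so adjust these to your own setup.

    # Hypothetical usage sketch; the paths below are placeholders.
    from imitation.scripts import train_adversarial

    # Settings forwarded verbatim to the behavior cloning ("bc") pre-training run.
    bc_config = {
        "config_updates": {
            "logging": dict(log_root="/tmp/warmstart_demo"),
            "expert": dict(policy_type="ppo-huggingface"),
            "demonstrations": dict(rollout_path="path/to/cartpole_rollouts"),
        },
        "named_configs": ["cartpole"],
    }

    run = train_adversarial.train_adversarial_ex.run(
        command_name="gail",  # or "airl"
        named_configs=["cartpole"],
        config_updates={
            "logging": dict(log_root="/tmp/warmstart_demo"),
            "demonstrations": dict(rollout_path="path/to/cartpole_rollouts"),
            "warm_start_with_bc": True,  # pre-train the generator with BC first
            "bc_config": bc_config,
        },
    )
    assert run.status == "COMPLETED"

Under this flow, the behavior cloning run saves its policy as final.th in its log directory; train_adversarial then loads that file with torch and assigns it as the generator's initial policy before adversarial training begins.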