From 5d195896d81576d1e9c203c55912fc37eb42074a Mon Sep 17 00:00:00 2001 From: Edilmo Palencia Date: Mon, 5 Oct 2020 19:56:47 -0700 Subject: [PATCH] RNN support in DDPG and SAC --- rllib/agents/ddpg/ddpg_tf_policy.py | 32 ++++++++++++++++--- rllib/agents/ddpg/ddpg_torch_policy.py | 28 ++++++++++++++--- rllib/agents/dqn/dqn_tf_policy.py | 43 +++++++++++++++----------- rllib/agents/dqn/dqn_torch_policy.py | 31 ++++++++++++++++--- rllib/agents/sac/sac_tf_policy.py | 22 ++++++++++--- rllib/agents/sac/sac_torch_policy.py | 30 +++++++++++++----- rllib/policy/torch_policy.py | 2 ++ rllib/utils/tf_ops.py | 23 +++++++++++--- 8 files changed, 164 insertions(+), 47 deletions(-) diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index b57b53de1b68..8ca865b04893 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -105,6 +105,8 @@ def build_ddpg_models(policy, observation_space, action_space, config): def get_distribution_inputs_and_class(policy, model, obs_batch, + state_batches, + seq_lens, *, explore=True, is_training=False, @@ -112,7 +114,7 @@ def get_distribution_inputs_and_class(policy, model_out, _ = model({ "obs": obs_batch, "is_training": is_training, - }, [], None) + }, state_batches, seq_lens) dist_inputs = model.get_policy_output(model_out) return dist_inputs, (TorchDeterministic @@ -128,6 +130,18 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): huber_threshold = policy.config["huber_threshold"] l2_reg = policy.config["l2_reg"] + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + input_dict = { "obs": train_batch[SampleBatch.CUR_OBS], "is_training": True, @@ -137,9 +151,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): "is_training": True, } - model_out_t, _ = model(input_dict, [], None) - model_out_tp1, _ = model(input_dict_next, [], None) - target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None) + model_out_t, _ = model(input_dict, states_in, seq_lens) + model_out_tp1, _ = model(input_dict_next, states_out, seq_lens) + target_model_out_tp1, _ = policy.target_model(input_dict_next, states_out, seq_lens) # Policy network evaluation. 
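# --- Illustration (not part of the patch): the state-extraction pattern above. ---
# A minimal, framework-free sketch of the "collect state_in_*/state_out_*/seq_lens
# from the train batch" loop that this patch repeats inside each loss function.
# The helper name `extract_rnn_state` is hypothetical; the patch keeps the loops inline.
import numpy as np


def extract_rnn_state(train_batch):
    """Return (states_in, states_out, seq_lens) found in a batch-like dict."""
    states_in, states_out = [], []
    i = 0
    while "state_in_{}".format(i) in train_batch:
        states_in.append(train_batch["state_in_{}".format(i)])
        i += 1
    i = 0
    while "state_out_{}".format(i) in train_batch:
        states_out.append(train_batch["state_out_{}".format(i)])
        i += 1
    seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []
    return states_in, states_out, seq_lens


# Usage: with a recurrent model the batch carries state columns; without one the
# lists stay empty and model(input_dict, [], seq_lens) behaves as before the patch.
batch = {
    "obs": np.zeros((4, 3)),
    "state_in_0": np.zeros((4, 8)),
    "state_out_0": np.zeros((4, 8)),
    "seq_lens": np.array([2, 2]),
}
states_in, states_out, seq_lens = extract_rnn_state(batch)
assert len(states_in) == 1 and len(states_out) == 1 and list(seq_lens) == [2, 2]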
with tf.variable_scope(POLICY_SCOPE, reuse=True): @@ -246,6 +260,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS] input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES] input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS] + for i, h in enumerate(states_in): + input_dict["state_in_{}".format(i)] = h + for i, h in enumerate(states_out): + input_dict["state_out_{}".format(i)] = h if log_once("ddpg_custom_loss"): logger.warning( "You are using a state-preprocessor with DDPG and " @@ -362,7 +380,10 @@ class ComputeTDErrorMixin: def __init__(self, loss_fn): @make_tf_callable(self.get_session(), dynamic_shape=True) def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs # Do forward pass on loss to update td errors attribute # (one TD-error value per item in batch to update PR weights). loss_fn( @@ -373,6 +394,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1), SampleBatch.DONES: tf.convert_to_tensor(done_mask), PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights), + **state_in_out_and_seq_lens }) # `self.td_error` is set in loss_fn. return self.td_error diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index 1c8d1d7a9bfa..6337e8ec14a5 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -35,6 +35,18 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): huber_threshold = policy.config["huber_threshold"] l2_reg = policy.config["l2_reg"] + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + input_dict = { "obs": train_batch[SampleBatch.CUR_OBS], "is_training": True, @@ -44,9 +56,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): "is_training": True, } - model_out_t, _ = model(input_dict, [], None) - model_out_tp1, _ = model(input_dict_next, [], None) - target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None) + model_out_t, _ = model(input_dict, states_in, seq_lens) + model_out_tp1, _ = model(input_dict_next, states_out, seq_lens) + target_model_out_tp1, _ = policy.target_model(input_dict_next, states_out, seq_lens) # Policy network evaluation. 
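# --- Illustration (not part of the patch): the new compute_td_error calling contract. ---
# The ComputeTDErrorMixin changes above keep the six positional arguments and accept the
# recurrent state purely through **kwargs, so non-recurrent callers stay unchanged. The
# stub below only mirrors that signature and the dict merge; the real method builds a
# SampleBatch-style input dict, runs the loss, and returns self.td_error.
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                     importance_weights, **kwargs):
    # kwargs should contain only state_in_*/state_out_* tensors and seq_lens.
    state_in_out_and_seq_lens = kwargs
    input_dict = {
        "obs": obs_t,
        "actions": act_t,
        "rewards": rew_t,
        "new_obs": obs_tp1,
        "dones": done_mask,
        "weights": importance_weights,
    }
    input_dict.update(state_in_out_and_seq_lens)
    return input_dict


# A recurrent caller forwards the state columns it collected from the sample batch:
state_kwargs = {"state_in_0": [[0.0] * 8], "state_out_0": [[0.0] * 8], "seq_lens": [1]}
out = compute_td_error([[1.0]], [0], [0.5], [[1.5]], [False], [1.0], **state_kwargs)
assert "state_in_0" in out and "seq_lens" in out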
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) @@ -146,6 +158,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS] input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES] input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS] + for i, h in enumerate(states_in): + input_dict["state_in_{}".format(i)] = h + for i, h in enumerate(states_out): + input_dict["state_out_{}".format(i)] = h [actor_loss, critic_loss] = model.custom_loss( [actor_loss, critic_loss], input_dict) @@ -214,7 +230,10 @@ def before_init_fn(policy, obs_space, action_space, config): class ComputeTDErrorMixin: def __init__(self, loss_fn): def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs input_dict = self._lazy_tensor_dict({ SampleBatch.CUR_OBS: obs_t, SampleBatch.ACTIONS: act_t, @@ -223,6 +242,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, SampleBatch.DONES: done_mask, PRIO_WEIGHTS: importance_weights, }) + input_dict.update(state_in_out_and_seq_lens) # Do forward pass on loss to update td errors attribute # (one TD-error value per item in batch to update PR weights). loss_fn(self, self.model, None, input_dict) diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index 0802ae284f8f..fe975784aa93 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -111,15 +111,11 @@ def __init__(self, class ComputeTDErrorMixin: def __init__(self): @make_tf_callable(self.get_session(), dynamic_shape=True) - def compute_td_error(obs_t, states_t, act_t, rew_t, obs_tp1, done_mask, - states_tp1, seq_lens, importance_weights): - state_in_out_and_seq_lens = {} - for i, h in enumerate(states_t): - state_in_out_and_seq_lens["state_in_{}".format(i)] = h - for i, h in enumerate(states_tp1): - state_in_out_and_seq_lens["state_out_{}".format(i)] = h - if state_in_out_and_seq_lens: - state_in_out_and_seq_lens["seq_lens"] = seq_lens + def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs # Do forward pass on loss to update td error attribute build_q_losses( self, self.model, None, { @@ -222,7 +218,7 @@ def build_q_losses(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else None + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.array([]) config = policy.config # q network evaluation q_t, q_logits_t, q_dist_t = compute_q_values( @@ -323,6 +319,8 @@ def setup_late_mixins(policy, obs_space, action_space, config): def compute_q_values(policy, model, obs, states, seq_lens, explore): config = policy.config + if type(states) is tuple: + states = list(states) model_out, state = model({ SampleBatch.CUR_OBS: obs, "is_training": policy._get_is_training_placeholder(), @@ -386,21 +384,30 @@ def _adjust_nstep(n_step, gamma, obs, states_in, actions, rewards, new_obs, done new_obs[i] = new_obs[i + j] dones[i] = dones[i + j] rewards[i] += gamma**j * rewards[i + j] - states_out[i] = states_out[i + j] + if 
states_out: + states_out[i] = states_out[i + j] def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None): + state_in_out_and_seq_lens = {} states_in = [] i = 0 - while "state_in_{}".format(i) in batch: - states_in.append(batch["state_in_{}".format(i)]) + key ="state_in_{}".format(i) + while key in batch: + states_in.append(batch[key]) + state_in_out_and_seq_lens[key] = batch[key] i += 1 + key ="state_in_{}".format(i) states_out = [] i = 0 - while "state_out_{}".format(i) in batch: - states_out.append(batch["state_out_{}".format(i)]) + key ="state_out_{}".format(i) + while key in batch: + states_out.append(batch[key]) + state_in_out_and_seq_lens[key] = batch[key] i += 1 - seq_lens = batch["seq_lens"] if "seq_lens" in batch else None + key = "state_out_{}".format(i) + seq_lens = batch["seq_lens"] if "seq_lens" in batch else np.array([]) + state_in_out_and_seq_lens["seq_lens"] = seq_lens # N-step Q adjustments if policy.config["n_step"] > 1: _adjust_nstep(policy.config["n_step"], policy.config["gamma"], @@ -419,9 +426,9 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None): elif isinstance(policy.action_space, Discrete) and actions.shape[-1] == 1: actions = np.reshape(actions, [-1]) td_errors = policy.compute_td_error( - batch[SampleBatch.CUR_OBS], states_in, actions, + batch[SampleBatch.CUR_OBS], actions, batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], - batch[SampleBatch.DONES], states_out, seq_lens, batch[PRIO_WEIGHTS]) + batch[SampleBatch.DONES], batch[PRIO_WEIGHTS], **state_in_out_and_seq_lens) new_priorities = ( np.abs(td_errors) + policy.config["prioritized_replay_eps"]) batch.data[PRIO_WEIGHTS] = new_priorities diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index 22bf7905f70e..f57a85d4bee1 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -60,13 +60,17 @@ def __init__(self, class ComputeTDErrorMixin: def __init__(self): def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t}) input_dict[SampleBatch.ACTIONS] = act_t input_dict[SampleBatch.REWARDS] = rew_t input_dict[SampleBatch.NEXT_OBS] = obs_tp1 input_dict[SampleBatch.DONES] = done_mask input_dict[PRIO_WEIGHTS] = importance_weights + input_dict.update(state_in_out_and_seq_lens) # Do forward pass on loss to update td error attribute build_q_losses(self, self.model, None, input_dict) @@ -137,11 +141,13 @@ def build_q_model_and_distribution(policy, obs_space, action_space, config): def get_distribution_inputs_and_class(policy, model, obs_batch, + state_batches, + seq_lens, *, explore=True, is_training=False, **kwargs): - q_vals = compute_q_values(policy, model, obs_batch, explore, is_training) + q_vals = compute_q_values(policy, model, obs_batch, state_batches, seq_lens, explore, is_training) q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals policy.q_values = q_vals @@ -149,12 +155,25 @@ def get_distribution_inputs_and_class(policy, def build_q_losses(policy, model, _, train_batch): + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + 
states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] config = policy.config # q network evaluation q_t = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.CUR_OBS], + states_in, + seq_lens, explore=False, is_training=True) @@ -163,6 +182,8 @@ def build_q_losses(policy, model, _, train_batch): policy, policy.target_q_model, train_batch[SampleBatch.NEXT_OBS], + states_out, + seq_lens, explore=False, is_training=True) @@ -177,6 +198,8 @@ def build_q_losses(policy, model, _, train_batch): policy, policy.q_model, train_batch[SampleBatch.NEXT_OBS], + states_out, + seq_lens, explore=False, is_training=True) q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1) @@ -221,14 +244,14 @@ def after_init(policy, obs_space, action_space, config): policy.target_q_model = policy.target_q_model.to(policy.device) -def compute_q_values(policy, model, obs, explore, is_training=False): +def compute_q_values(policy, model, obs, states, seq_lens, explore, is_training=False): if policy.config["num_atoms"] > 1: raise ValueError("torch DQN does not support distributional DQN yet!") model_out, state = model({ SampleBatch.CUR_OBS: obs, "is_training": is_training, - }, [], None) + }, states, seq_lens) advantages_or_q_values = model.get_advantages_or_q_values(model_out) diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 94cb2553b26f..19b7583aa9d9 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -116,6 +116,8 @@ def get_dist_class(config, action_space): def get_distribution_inputs_and_class(policy, model, obs_batch, + state_batches, + seq_lens, *, explore=True, **kwargs): @@ -123,7 +125,7 @@ def get_distribution_inputs_and_class(policy, model_out, state_out = model({ "obs": obs_batch, "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, state_batches, seq_lens) # Get action model output from base-model output. distribution_inputs = model.get_policy_output(model_out) action_dist_class = get_dist_class(policy.config, policy.action_space) @@ -134,20 +136,32 @@ def sac_actor_critic_loss(policy, model, _, train_batch): # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, states_in, seq_lens) model_out_tp1, _ = model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, states_out, seq_lens) target_model_out_tp1, _ = policy.target_model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, states_out, seq_lens) # Discrete case. 
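# --- Illustration (not part of the patch): the n-step folding in _adjust_nstep. ---
# A simplified, standalone version of the n-step return adjustment from the
# dqn_tf_policy hunk above, including the new guard that only shifts recurrent
# state when state_out columns are actually present. Trajectory values are made
# up; the update rule (fold the next n-1 discounted rewards into step i and move
# new_obs/dones forward) follows the hunk, but this is a sketch, not the RLlib
# function, and it omits the unrelated arguments.
def adjust_nstep(n_step, gamma, rewards, new_obs, dones, states_out):
    traj_length = len(rewards)
    for i in range(traj_length):
        for j in range(1, n_step):
            if i + j < traj_length:
                new_obs[i] = new_obs[i + j]
                dones[i] = dones[i + j]
                rewards[i] += gamma ** j * rewards[i + j]
                if states_out:
                    states_out[i] = states_out[i + j]


rewards = [1.0, 1.0, 1.0, 1.0]
new_obs = [1, 2, 3, 4]
dones = [False, False, False, True]
adjust_nstep(3, 0.9, rewards, new_obs, dones, states_out=[])
# Step 0 now carries r0 + 0.9*r1 + 0.81*r2 and points at the observation 2 steps ahead.
assert abs(rewards[0] - (1.0 + 0.9 + 0.81)) < 1e-6 and new_obs[0] == 3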
if model.discrete: diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 503bc4979498..3660540fe64d 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -39,9 +39,9 @@ def get_dist_class(config, action_space): def action_distribution_fn(policy, model, obs_batch, + state_batches, + seq_lens, *, - state_batches=None, - seq_lens=None, prev_action_batch=None, prev_reward_batch=None, explore=None, @@ -50,7 +50,7 @@ def action_distribution_fn(policy, model_out, _ = model({ "obs": obs_batch, "is_training": is_training, - }, [], None) + }, state_batches, seq_lens) distribution_inputs = model.get_policy_output(model_out) action_dist_class = get_dist_class(policy.config, policy.action_space) @@ -61,20 +61,32 @@ def actor_critic_loss(policy, model, _, train_batch): # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], "is_training": True, - }, [], None) + }, states_in, seq_lens) model_out_tp1, _ = model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": True, - }, [], None) + }, states_out, seq_lens) target_model_out_tp1, _ = policy.target_model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": True, - }, [], None) + }, states_out, seq_lens) alpha = torch.exp(model.log_alpha) @@ -280,7 +292,10 @@ def optimizer_fn(policy, config): class ComputeTDErrorMixin: def __init__(self): def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs input_dict = self._lazy_tensor_dict({ SampleBatch.CUR_OBS: obs_t, SampleBatch.ACTIONS: act_t, @@ -289,6 +304,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, SampleBatch.DONES: done_mask, PRIO_WEIGHTS: importance_weights, }) + input_dict.update(state_in_out_and_seq_lens) # Do forward pass on loss to update td errors attribute # (one TD-error value per item in batch to update PR weights). 
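# --- Illustration (not part of the patch): the updated action_distribution_fn contract. ---
# In sac_torch_policy, state_batches and seq_lens move from optional keyword arguments to
# positional parameters, and torch_policy.compute_actions (see the hunk above) now passes
# them explicitly. The policy/model objects below are hypothetical stand-ins; only the
# call shape mirrors the patch, not the real model API or return values.
def action_distribution_fn(policy, model, obs_batch, state_batches, seq_lens, *,
                           explore=True, is_training=False, **kwargs):
    model_out, state_out = model({"obs": obs_batch, "is_training": is_training},
                                 state_batches, seq_lens)
    return model_out, "dist_class", state_out


def fake_model(input_dict, states, seq_lens):
    # Echo the inputs so the example is self-contained and checkable.
    return {"obs_seen": input_dict["obs"], "n_states": len(states)}, states


logits, dist_cls, state_out = action_distribution_fn(
    policy=None, model=fake_model, obs_batch=[[0.0, 1.0]],
    state_batches=[[[0.0] * 8]], seq_lens=[1], explore=False)
assert logits["n_states"] == 1 and state_out == [[[0.0] * 8]]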
actor_critic_loss(self, self.model, None, input_dict) diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index c13ff5f8afe5..13562158705d 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -142,6 +142,8 @@ def compute_actions(self, self, self.model, input_dict[SampleBatch.CUR_OBS], + state_batches=state_batches, + seq_lens=seq_lens, explore=explore, timestep=timestep, is_training=False) diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index ea636f744f04..e8871e947fb6 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -58,10 +58,11 @@ def make_tf_callable(session_or_none, dynamic_shape=False): def make_wrapper(fn): if session_or_none: - placeholders = [] + args_placeholders = [] + kwargs_placeholders = {} symbolic_out = [None] - def call(*args): + def call(*args, **kwargs): args_flat = [] for a in args: if type(a) is list: @@ -79,13 +80,25 @@ def call(*args): shape = () else: shape = v.shape - placeholders.append( + args_placeholders.append( tf.placeholder( dtype=v.dtype, shape=shape, name="arg_{}".format(i))) - symbolic_out[0] = fn(*placeholders) - feed_dict = dict(zip(placeholders, args)) + for k, v in kwargs.items(): + if dynamic_shape: + if len(v.shape) > 0: + shape = (None, ) + v.shape[1:] + else: + shape = () + else: + shape = v.shape + kwargs_placeholders[k] = tf.placeholder( + dtype=v.dtype, + shape=shape, + name="karg_{}".format(k)) + symbolic_out[0] = fn(*args_placeholders, **kwargs_placeholders) + feed_dict = dict(zip(args_placeholders, args)) ret = session_or_none.run(symbolic_out[0], feed_dict) return ret
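# --- Illustration (not part of the patch): why kwargs get their own named placeholders. ---
# A TF1-free analogue of the make_tf_callable change above: positional args are matched
# to cached placeholders by position (zip), while keyword args must be matched by name,
# which is why the patch keeps a dict keyed by kwarg name ("karg_<name>"). At call time
# the keyword values also need feed entries alongside the positional ones; this sketch
# assumes one feed entry per kwargs placeholder. Everything here (Placeholder,
# make_callable) is a stand-in for illustration, not the RLlib or TensorFlow API.
class Placeholder:
    def __init__(self, name):
        self.name = name


def make_callable(fn):
    args_placeholders = []
    kwargs_placeholders = {}
    symbolic_out = [None]

    def call(*args, **kwargs):
        if symbolic_out[0] is None:
            # Build placeholders once, on the first call, keyed by position or name.
            for i, _ in enumerate(args):
                args_placeholders.append(Placeholder("arg_{}".format(i)))
            for k in kwargs:
                kwargs_placeholders[k] = Placeholder("karg_{}".format(k))
            symbolic_out[0] = fn(*args_placeholders, **kwargs_placeholders)
        feed_dict = dict(zip(args_placeholders, args))
        feed_dict.update({kwargs_placeholders[k]: v for k, v in kwargs.items()})
        # A real session.run would evaluate symbolic_out[0] under feed_dict here;
        # we just return the fed values keyed by placeholder name.
        return {ph.name: v for ph, v in feed_dict.items()}

    return call


@make_callable
def td_error(obs_t, act_t, **state_kwargs):
    return (obs_t, act_t, sorted(state_kwargs))


fed = td_error([[1.0]], [0], state_in_0=[[0.0] * 8], seq_lens=[1])
assert set(fed) == {"arg_0", "arg_1", "karg_state_in_0", "karg_seq_lens"}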