RNN support in DDPG and SAC
Edilmo committed Oct 6, 2020
1 parent a08fc78 commit 5d19589
Showing 8 changed files with 164 additions and 47 deletions.
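
Every loss function touched by this commit repeats the same lookup to pull the recurrent state columns (state_in_*, state_out_*) and the seq_lens column out of a train batch. A minimal standalone sketch of that pattern; the helper name is illustrative and not part of the commit:

# Illustrative helper (not in this commit): gathers the recurrent state
# columns and seq_lens that the updated loss functions read from a batch.
def extract_rnn_columns(train_batch):
    states_in, states_out = [], []
    i = 0
    while "state_in_{}".format(i) in train_batch:
        states_in.append(train_batch["state_in_{}".format(i)])
        i += 1
    i = 0
    while "state_out_{}".format(i) in train_batch:
        states_out.append(train_batch["state_out_{}".format(i)])
        i += 1
    seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []
    return states_in, states_out, seq_lens

# Example: a batch with a single recurrent state slot.
# extract_rnn_columns({"obs": [1], "state_in_0": [0.0], "state_out_0": [0.1],
#                      "seq_lens": [1]}) -> ([[0.0]], [[0.1]], [1])
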
32 changes: 27 additions & 5 deletions rllib/agents/ddpg/ddpg_tf_policy.py
@@ -105,14 +105,16 @@ def build_ddpg_models(policy, observation_space, action_space, config):
def get_distribution_inputs_and_class(policy,
model,
obs_batch,
state_batches,
seq_lens,
*,
explore=True,
is_training=False,
**kwargs):
model_out, _ = model({
"obs": obs_batch,
"is_training": is_training,
}, [], None)
}, state_batches, seq_lens)
dist_inputs = model.get_policy_output(model_out)

return dist_inputs, (TorchDeterministic
@@ -128,6 +130,18 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
huber_threshold = policy.config["huber_threshold"]
l2_reg = policy.config["l2_reg"]

states_in = []
i = 0
while "state_in_{}".format(i) in train_batch:
states_in.append(train_batch["state_in_{}".format(i)])
i += 1
states_out = []
i = 0
while "state_out_{}".format(i) in train_batch:
states_out.append(train_batch["state_out_{}".format(i)])
i += 1
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []

input_dict = {
"obs": train_batch[SampleBatch.CUR_OBS],
"is_training": True,
@@ -137,9 +151,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
"is_training": True,
}

model_out_t, _ = model(input_dict, [], None)
model_out_tp1, _ = model(input_dict_next, [], None)
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
model_out_t, _ = model(input_dict, states_in, seq_lens)
model_out_tp1, _ = model(input_dict_next, states_out, seq_lens)
target_model_out_tp1, _ = policy.target_model(input_dict_next, states_out, seq_lens)

# Policy network evaluation.
with tf.variable_scope(POLICY_SCOPE, reuse=True):
@@ -246,6 +260,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
for i, h in enumerate(states_in):
input_dict["state_in_{}".format(i)] = h
for i, h in enumerate(states_out):
input_dict["state_out_{}".format(i)] = h
if log_once("ddpg_custom_loss"):
logger.warning(
"You are using a state-preprocessor with DDPG and "
@@ -362,7 +380,10 @@ class ComputeTDErrorMixin:
def __init__(self, loss_fn):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
importance_weights, **kwargs):
# kwargs should contain only state input tensors,
# state output tensors and the seq_lens tensor
state_in_out_and_seq_lens = kwargs
# Do forward pass on loss to update td errors attribute
# (one TD-error value per item in batch to update PR weights).
loss_fn(
@@ -373,6 +394,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
SampleBatch.DONES: tf.convert_to_tensor(done_mask),
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
**state_in_out_and_seq_lens
})
# `self.td_error` is set in loss_fn.
return self.td_error
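
With the **kwargs signature introduced above, callers forward any recurrent state columns and the seq_lens tensor as keyword arguments; a state-free batch simply passes none of them. A hypothetical call site (batch keys are shown by their string values; names and shapes are illustrative):

# Hypothetical call of the updated mixin; a non-recurrent policy would
# omit the three keyword arguments entirely.
td_error = policy.compute_td_error(
    batch["obs"], batch["actions"], batch["rewards"],
    batch["new_obs"], batch["dones"], batch["weights"],
    state_in_0=batch["state_in_0"],
    state_out_0=batch["state_out_0"],
    seq_lens=batch["seq_lens"])
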
28 changes: 24 additions & 4 deletions rllib/agents/ddpg/ddpg_torch_policy.py
@@ -35,6 +35,18 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
huber_threshold = policy.config["huber_threshold"]
l2_reg = policy.config["l2_reg"]

states_in = []
i = 0
while "state_in_{}".format(i) in train_batch:
states_in.append(train_batch["state_in_{}".format(i)])
i += 1
states_out = []
i = 0
while "state_out_{}".format(i) in train_batch:
states_out.append(train_batch["state_out_{}".format(i)])
i += 1
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []

input_dict = {
"obs": train_batch[SampleBatch.CUR_OBS],
"is_training": True,
@@ -44,9 +56,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
"is_training": True,
}

model_out_t, _ = model(input_dict, [], None)
model_out_tp1, _ = model(input_dict_next, [], None)
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
model_out_t, _ = model(input_dict, states_in, seq_lens)
model_out_tp1, _ = model(input_dict_next, states_out, seq_lens)
target_model_out_tp1, _ = policy.target_model(input_dict_next, states_out, seq_lens)

# Policy network evaluation.
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
@@ -146,6 +158,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
for i, h in enumerate(states_in):
input_dict["state_in_{}".format(i)] = h
for i, h in enumerate(states_out):
input_dict["state_out_{}".format(i)] = h
[actor_loss, critic_loss] = model.custom_loss(
[actor_loss, critic_loss], input_dict)

@@ -214,7 +230,10 @@ def before_init_fn(policy, obs_space, action_space, config):
class ComputeTDErrorMixin:
def __init__(self, loss_fn):
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
importance_weights, **kwargs):
# kwargs should contain only state input tensors,
# state output tensors and the seq_lens tensor
state_in_out_and_seq_lens = kwargs
input_dict = self._lazy_tensor_dict({
SampleBatch.CUR_OBS: obs_t,
SampleBatch.ACTIONS: act_t,
@@ -223,6 +242,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
SampleBatch.DONES: done_mask,
PRIO_WEIGHTS: importance_weights,
})
input_dict.update(state_in_out_and_seq_lens)
# Do forward pass on loss to update td errors attribute
# (one TD-error value per item in batch to update PR weights).
loss_fn(self, self.model, None, input_dict)
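
The torch loss above now calls model(input_dict, states_in, seq_lens) instead of model(input_dict, [], None), so whichever model is plugged in must follow the usual (input_dict, state, seq_lens) -> (features, new_state) call contract. A toy stand-in illustrating that contract (not from the commit; real models subclass RLlib's model classes):

import torch

class ToyRecurrentModel:
    # Minimal stand-in for the (input_dict, state, seq_lens) call contract.
    def __init__(self, obs_dim, hidden_size=32):
        self.cell = torch.nn.GRUCell(obs_dim, hidden_size)

    def __call__(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        h_in = state[0] if state else torch.zeros(
            obs.shape[0], self.cell.hidden_size)
        h_out = self.cell(obs, h_in)
        # seq_lens is ignored here; a real model would use it to mask
        # padded timesteps when processing whole sequences.
        return h_out, [h_out]

model = ToyRecurrentModel(obs_dim=4)
feats, state = model({"obs": torch.randn(8, 4)}, [], None)     # state-free call
feats, state = model({"obs": torch.randn(8, 4)}, state, None)  # recurrent call
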
43 changes: 25 additions & 18 deletions rllib/agents/dqn/dqn_tf_policy.py
@@ -111,15 +111,11 @@ def __init__(self,
class ComputeTDErrorMixin:
def __init__(self):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, states_t, act_t, rew_t, obs_tp1, done_mask,
states_tp1, seq_lens, importance_weights):
state_in_out_and_seq_lens = {}
for i, h in enumerate(states_t):
state_in_out_and_seq_lens["state_in_{}".format(i)] = h
for i, h in enumerate(states_tp1):
state_in_out_and_seq_lens["state_out_{}".format(i)] = h
if state_in_out_and_seq_lens:
state_in_out_and_seq_lens["seq_lens"] = seq_lens
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights, **kwargs):
# kwargs should contain only state input tensors,
# state output tensors and the seq_lens tensor
state_in_out_and_seq_lens = kwargs
# Do forward pass on loss to update td error attribute
build_q_losses(
self, self.model, None, {
@@ -222,7 +218,7 @@ def build_q_losses(policy, model, _, train_batch):
while "state_out_{}".format(i) in train_batch:
states_out.append(train_batch["state_out_{}".format(i)])
i += 1
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else None
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.array([])
config = policy.config
# q network evaluation
q_t, q_logits_t, q_dist_t = compute_q_values(
@@ -323,6 +319,8 @@ def setup_late_mixins(policy, obs_space, action_space, config):
def compute_q_values(policy, model, obs, states, seq_lens, explore):
config = policy.config

if type(states) is tuple:
states = list(states)
model_out, state = model({
SampleBatch.CUR_OBS: obs,
"is_training": policy._get_is_training_placeholder(),
@@ -386,21 +384,30 @@ def _adjust_nstep(n_step, gamma, obs, states_in, actions, rewards, new_obs, done
new_obs[i] = new_obs[i + j]
dones[i] = dones[i + j]
rewards[i] += gamma**j * rewards[i + j]
states_out[i] = states_out[i + j]
if states_out:
states_out[i] = states_out[i + j]


def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
state_in_out_and_seq_lens = {}
states_in = []
i = 0
while "state_in_{}".format(i) in batch:
states_in.append(batch["state_in_{}".format(i)])
key = "state_in_{}".format(i)
while key in batch:
states_in.append(batch[key])
state_in_out_and_seq_lens[key] = batch[key]
i += 1
key = "state_in_{}".format(i)
states_out = []
i = 0
while "state_out_{}".format(i) in batch:
states_out.append(batch["state_out_{}".format(i)])
key = "state_out_{}".format(i)
while key in batch:
states_out.append(batch[key])
state_in_out_and_seq_lens[key] = batch[key]
i += 1
seq_lens = batch["seq_lens"] if "seq_lens" in batch else None
key = "state_out_{}".format(i)
seq_lens = batch["seq_lens"] if "seq_lens" in batch else np.array([])
state_in_out_and_seq_lens["seq_lens"] = seq_lens
# N-step Q adjustments
if policy.config["n_step"] > 1:
_adjust_nstep(policy.config["n_step"], policy.config["gamma"],
@@ -419,9 +426,9 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
elif isinstance(policy.action_space, Discrete) and actions.shape[-1] == 1:
actions = np.reshape(actions, [-1])
td_errors = policy.compute_td_error(
batch[SampleBatch.CUR_OBS], states_in, actions,
batch[SampleBatch.CUR_OBS], actions,
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES], states_out, seq_lens, batch[PRIO_WEIGHTS])
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS], **state_in_out_and_seq_lens)
new_priorities = (
np.abs(td_errors) + policy.config["prioritized_replay_eps"])
batch.data[PRIO_WEIGHTS] = new_priorities
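
The n-step block above folds future rewards into each transition before TD errors are computed, and with this change the forwarded states_out entry is only rewritten when recurrent state is actually present. A small numeric illustration of the reward folding (the values are made up):

# With gamma = 0.9 and n_step = 3, a transition followed by rewards 2.0
# and 4.0 (and no early done) collapses to a single 3-step reward:
#   1.0 + 0.9 * 2.0 + 0.9**2 * 4.0 = 6.04
# while its new_obs, dones (and states_out, if present) are taken from
# the transition two steps ahead.
gamma, rewards = 0.9, [1.0, 2.0, 4.0]
n_step_reward = sum(gamma ** j * rewards[j] for j in range(3))
assert abs(n_step_reward - 6.04) < 1e-9
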
31 changes: 27 additions & 4 deletions rllib/agents/dqn/dqn_torch_policy.py
@@ -60,13 +60,17 @@ def __init__(self,
class ComputeTDErrorMixin:
def __init__(self):
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
importance_weights, **kwargs):
# kwargs should contain only state input tensors,
# state output tensors and the seq_lens tensor
state_in_out_and_seq_lens = kwargs
input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
input_dict[SampleBatch.ACTIONS] = act_t
input_dict[SampleBatch.REWARDS] = rew_t
input_dict[SampleBatch.NEXT_OBS] = obs_tp1
input_dict[SampleBatch.DONES] = done_mask
input_dict[PRIO_WEIGHTS] = importance_weights
input_dict.update(state_in_out_and_seq_lens)

# Do forward pass on loss to update td error attribute
build_q_losses(self, self.model, None, input_dict)
@@ -137,24 +141,39 @@ def build_q_model_and_distribution(policy, obs_space, action_space, config):
def get_distribution_inputs_and_class(policy,
model,
obs_batch,
state_batches,
seq_lens,
*,
explore=True,
is_training=False,
**kwargs):
q_vals = compute_q_values(policy, model, obs_batch, explore, is_training)
q_vals = compute_q_values(policy, model, obs_batch, state_batches, seq_lens, explore, is_training)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

policy.q_values = q_vals
return policy.q_values, TorchCategorical, [] # state-out


def build_q_losses(policy, model, _, train_batch):
states_in = []
i = 0
while "state_in_{}".format(i) in train_batch:
states_in.append(train_batch["state_in_{}".format(i)])
i += 1
states_out = []
i = 0
while "state_out_{}".format(i) in train_batch:
states_out.append(train_batch["state_out_{}".format(i)])
i += 1
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []
config = policy.config
# q network evaluation
q_t = compute_q_values(
policy,
policy.q_model,
train_batch[SampleBatch.CUR_OBS],
states_in,
seq_lens,
explore=False,
is_training=True)

@@ -163,6 +182,8 @@ def build_q_losses(policy, model, _, train_batch):
policy,
policy.target_q_model,
train_batch[SampleBatch.NEXT_OBS],
states_out,
seq_lens,
explore=False,
is_training=True)

@@ -177,6 +198,8 @@ def build_q_losses(policy, model, _, train_batch):
policy,
policy.q_model,
train_batch[SampleBatch.NEXT_OBS],
states_out,
seq_lens,
explore=False,
is_training=True)
q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
@@ -221,14 +244,14 @@ def after_init(policy, obs_space, action_space, config):
policy.target_q_model = policy.target_q_model.to(policy.device)


def compute_q_values(policy, model, obs, explore, is_training=False):
def compute_q_values(policy, model, obs, states, seq_lens, explore, is_training=False):
if policy.config["num_atoms"] > 1:
raise ValueError("torch DQN does not support distributional DQN yet!")

model_out, state = model({
SampleBatch.CUR_OBS: obs,
"is_training": is_training,
}, [], None)
}, states, seq_lens)

advantages_or_q_values = model.get_advantages_or_q_values(model_out)

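
compute_q_values on the torch side now takes explicit state batches and seq_lens, matching the TF version. A hedged usage sketch; obs_batch, h_in and seq_lens are placeholders rather than names from the commit:

# State-free call: empty state list and no sequence lengths.
q_vals = compute_q_values(policy, policy.q_model, obs_batch, [], None,
                          explore=False, is_training=False)
# Recurrent call: one tensor per state slot plus per-sequence lengths.
q_vals = compute_q_values(policy, policy.q_model, obs_batch, [h_in], seq_lens,
                          explore=False, is_training=False)
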
22 changes: 18 additions & 4 deletions rllib/agents/sac/sac_tf_policy.py
@@ -116,14 +116,16 @@ def get_dist_class(config, action_space):
def get_distribution_inputs_and_class(policy,
model,
obs_batch,
state_batches,
seq_lens,
*,
explore=True,
**kwargs):
# Get base-model output.
model_out, state_out = model({
"obs": obs_batch,
"is_training": policy._get_is_training_placeholder(),
}, [], None)
}, state_batches, seq_lens)
# Get action model output from base-model output.
distribution_inputs = model.get_policy_output(model_out)
action_dist_class = get_dist_class(policy.config, policy.action_space)
@@ -134,20 +136,32 @@ def sac_actor_critic_loss(policy, model, _, train_batch):
# Should be True only for debugging purposes (e.g. test cases)!
deterministic = policy.config["_deterministic_loss"]

states_in = []
i = 0
while "state_in_{}".format(i) in train_batch:
states_in.append(train_batch["state_in_{}".format(i)])
i += 1
states_out = []
i = 0
while "state_out_{}".format(i) in train_batch:
states_out.append(train_batch["state_out_{}".format(i)])
i += 1
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []

model_out_t, _ = model({
"obs": train_batch[SampleBatch.CUR_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
}, states_in, seq_lens)

model_out_tp1, _ = model({
"obs": train_batch[SampleBatch.NEXT_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
}, states_out, seq_lens)

target_model_out_tp1, _ = policy.target_model({
"obs": train_batch[SampleBatch.NEXT_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
}, states_out, seq_lens)

# Discrete case.
if model.discrete:
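
As in the DDPG and DQN policies above, SAC's get_distribution_inputs_and_class now receives state_batches and seq_lens from the sampling path and forwards them into the base-model call. A hedged sketch of an equivalent direct call, assuming the usual (inputs, dist class, state-out) return; obs_batch, h_in and c_in are placeholders, and in practice the built policy makes this call internally:

# Direct call with recurrent state; a feed-forward model would pass
# state_batches=[] and seq_lens=None instead.
dist_inputs, dist_class, state_out = get_distribution_inputs_and_class(
    policy, policy.model, obs_batch,
    state_batches=[h_in, c_in],  # e.g. LSTM hidden and cell state
    seq_lens=None,               # assumption: no padding at sampling time
    explore=True)
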
(Diff for the remaining 3 changed files not shown.)
