State available in SampleBatch and ReplayBuffer #43

Open · wants to merge 6 commits into base: releases/0.8.6
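At a glance, these changes thread recurrent state through the DDPG and DQN policies: the state_in_*/state_out_* columns and seq_lens are read out of the train batch, passed into the model calls in place of the previous "[], None" placeholders, and forwarded to compute_td_error as extra keyword arguments. Below is a minimal sketch of the batch layout the changed code iterates over; the keys match the diff, while the shapes and values are illustrative only.

# Illustrative sketch, not part of the diff: a toy train batch with the state
# columns and seq_lens key that the changed loss functions look for.
import numpy as np

train_batch = {
    "obs": np.zeros((4, 8), dtype=np.float32),           # SampleBatch.CUR_OBS
    "new_obs": np.zeros((4, 8), dtype=np.float32),       # SampleBatch.NEXT_OBS
    "state_in_0": np.zeros((4, 64), dtype=np.float32),   # RNN input state, column 0
    "state_out_0": np.zeros((4, 64), dtype=np.float32),  # RNN output state, column 0
    "seq_lens": np.ones(4, dtype=np.int32),               # defaults to all-ones when missing
}

# The extraction loop the diff repeats in each loss function:
states_in, i = [], 0
while "state_in_{}".format(i) in train_batch:
    states_in.append(train_batch["state_in_{}".format(i)])
    i += 1
# -> states_in == [train_batch["state_in_0"]]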
39 changes: 32 additions & 7 deletions rllib/agents/ddpg/ddpg_tf_policy.py
@@ -105,19 +105,21 @@ def build_ddpg_models(policy, observation_space, action_space, config):
def get_distribution_inputs_and_class(policy,
model,
obs_batch,
state_batches,
seq_lens,
*,
explore=True,
is_training=False,
**kwargs):
model_out, _ = model({
model_out, state_out = model({
"obs": obs_batch,
"is_training": is_training,
}, [], None)
}, state_batches, seq_lens)
dist_inputs = model.get_policy_output(model_out)

return dist_inputs, (TorchDeterministic
if policy.config["framework"] == "torch" else
Deterministic), [] # []=state out
Deterministic), state_out # state out


def ddpg_actor_critic_loss(policy, model, _, train_batch):
@@ -128,6 +130,21 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
huber_threshold = policy.config["huber_threshold"]
l2_reg = policy.config["l2_reg"]

states_in = []
i = 0
while "state_in_{}".format(i) in train_batch:
states_in.append(train_batch["state_in_{}".format(i)])
i += 1
states_out = []
i = 0
while "state_out_{}".format(i) in train_batch:
states_out.append(train_batch["state_out_{}".format(i)])
i += 1
batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0]
if isinstance(train_batch[SampleBatch.CUR_OBS], tf.Tensor)
else len(train_batch[SampleBatch.CUR_OBS]))
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size)

input_dict = {
"obs": train_batch[SampleBatch.CUR_OBS],
"is_training": True,
@@ -137,9 +154,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
"is_training": True,
}

model_out_t, _ = model(input_dict, [], None)
model_out_tp1, _ = model(input_dict_next, [], None)
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
model_out_t, states_out_t = model(input_dict, states_in, seq_lens)
model_out_tp1, states_out_tp1 = model(input_dict_next, states_out, seq_lens)
target_model_out_tp1, target_states_out_tp1 = policy.target_model(input_dict_next, states_out, seq_lens)

# Policy network evaluation.
with tf.variable_scope(POLICY_SCOPE, reuse=True):
@@ -246,6 +263,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
for i, h in enumerate(states_in):
input_dict["state_in_{}".format(i)] = h
for i, h in enumerate(states_out):
input_dict["state_out_{}".format(i)] = h
if log_once("ddpg_custom_loss"):
logger.warning(
"You are using a state-preprocessor with DDPG and "
@@ -362,7 +383,10 @@ class ComputeTDErrorMixin:
def __init__(self, loss_fn):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
importance_weights, **kwargs):
# kwargs should contain only state input tensors,
# state output tensors and the seq_lens tensor
state_in_out_and_seq_lens = kwargs
# Do forward pass on loss to update td errors attribute
# (one TD-error value per item in batch to update PR weights).
loss_fn(
@@ -373,6 +397,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
SampleBatch.DONES: tf.convert_to_tensor(done_mask),
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
**state_in_out_and_seq_lens
})
# `self.td_error` is set in loss_fn.
return self.td_error
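The state-extraction boilerplate above reappears in the Torch DDPG loss and the DQN losses below. A small helper along these lines (hypothetical, not part of this PR) would capture the shared pattern:

# Hypothetical refactoring sketch; the PR inlines this logic in each loss fn.
import numpy as np

def collect_state_columns(train_batch, prefix):
    """Gather "<prefix>_0", "<prefix>_1", ... from the batch, in order."""
    states, i = [], 0
    while "{}_{}".format(prefix, i) in train_batch:
        states.append(train_batch["{}_{}".format(prefix, i)])
        i += 1
    return states

def seq_lens_or_ones(train_batch, batch_size):
    """Fall back to all-ones sequence lengths, matching the diff's default."""
    return train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size)

# Usage (equivalent to the inlined loops):
#   states_in = collect_state_columns(train_batch, "state_in")
#   states_out = collect_state_columns(train_batch, "state_out")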
32 changes: 28 additions & 4 deletions rllib/agents/ddpg/ddpg_torch_policy.py
@@ -1,4 +1,5 @@
import logging
import numpy as np

import ray
from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \
@@ -35,6 +36,21 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
huber_threshold = policy.config["huber_threshold"]
l2_reg = policy.config["l2_reg"]

states_in = []
i = 0
while "state_in_{}".format(i) in train_batch:
states_in.append(train_batch["state_in_{}".format(i)])
i += 1
states_out = []
i = 0
while "state_out_{}".format(i) in train_batch:
states_out.append(train_batch["state_out_{}".format(i)])
i += 1
batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0]
if isinstance(train_batch[SampleBatch.CUR_OBS], torch.Tensor)
else len(train_batch[SampleBatch.CUR_OBS]))
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size)

input_dict = {
"obs": train_batch[SampleBatch.CUR_OBS],
"is_training": True,
@@ -44,9 +60,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
"is_training": True,
}

model_out_t, _ = model(input_dict, [], None)
model_out_tp1, _ = model(input_dict_next, [], None)
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
model_out_t, states_out_t = model(input_dict, states_in, seq_lens)
model_out_tp1, states_out_tp1 = model(input_dict_next, states_out, seq_lens)
target_model_out_tp1, target_states_out_tp1 = policy.target_model(input_dict_next, states_out, seq_lens)

# Policy network evaluation.
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
@@ -146,6 +162,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
for i, h in enumerate(states_in):
input_dict["state_in_{}".format(i)] = h
for i, h in enumerate(states_out):
input_dict["state_out_{}".format(i)] = h
[actor_loss, critic_loss] = model.custom_loss(
[actor_loss, critic_loss], input_dict)

@@ -214,7 +234,10 @@ def before_init_fn(policy, obs_space, action_space, config):
class ComputeTDErrorMixin:
def __init__(self, loss_fn):
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
importance_weights, **kwargs):
# kwargs should contain only state input tensors,
# state output tensors and the seq_lens tensor
state_in_out_and_seq_lens = kwargs
input_dict = self._lazy_tensor_dict({
SampleBatch.CUR_OBS: obs_t,
SampleBatch.ACTIONS: act_t,
@@ -223,6 +246,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
SampleBatch.DONES: done_mask,
PRIO_WEIGHTS: importance_weights,
})
input_dict.update(state_in_out_and_seq_lens)
# Do forward pass on loss to update td errors attribute
# (one TD-error value per item in batch to update PR weights).
loss_fn(self, self.model, None, input_dict)
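For reference, here is a plain-Python sketch of the merge the mixin above performs, with literal string keys standing in for the SampleBatch/PRIO_WEIGHTS constants; names and shapes are illustrative, and the real code wraps the dict in self._lazy_tensor_dict:

# Illustrative only: how the extra state/seq_lens kwargs end up in the
# loss input dict alongside the usual transition columns.
import numpy as np

def build_td_error_inputs(obs_t, act_t, rew_t, obs_tp1, done_mask,
                          importance_weights, **state_in_out_and_seq_lens):
    input_dict = {
        "obs": obs_t,                   # SampleBatch.CUR_OBS
        "actions": act_t,               # SampleBatch.ACTIONS
        "rewards": rew_t,               # SampleBatch.REWARDS
        "new_obs": obs_tp1,             # SampleBatch.NEXT_OBS
        "dones": done_mask,             # SampleBatch.DONES
        "weights": importance_weights,  # PRIO_WEIGHTS
    }
    # State inputs, state outputs and seq_lens ride along unchanged.
    input_dict.update(state_in_out_and_seq_lens)
    return input_dict

example = build_td_error_inputs(
    np.zeros((2, 4)), np.zeros((2, 1)), np.zeros(2), np.zeros((2, 4)),
    np.zeros(2, dtype=bool), np.ones(2),
    state_in_0=np.zeros((2, 64)), seq_lens=np.ones(2, dtype=np.int32))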
79 changes: 66 additions & 13 deletions rllib/agents/dqn/dqn_tf_policy.py
@@ -112,7 +112,10 @@ class ComputeTDErrorMixin:
def __init__(self):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
importance_weights, **kwargs):
# kwargs should contain only state input tensors,
# state output tensors and the seq_lens tensor
state_in_out_and_seq_lens = kwargs
# Do forward pass on loss to update td error attribute
build_q_losses(
self, self.model, None, {
@@ -122,6 +125,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
SampleBatch.DONES: tf.convert_to_tensor(done_mask),
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
**state_in_out_and_seq_lens
})

return self.q_loss.td_error
@@ -190,31 +194,52 @@ def build_q_model(policy, obs_space, action_space, config):
def get_distribution_inputs_and_class(policy,
model,
obs_batch,
state_batches,
seq_lens,
*,
explore=True,
**kwargs):
q_vals = compute_q_values(policy, model, obs_batch, explore)
q_vals = compute_q_values(policy, model, obs_batch, state_batches, seq_lens, explore)
state_out = q_vals[3] if isinstance(q_vals, tuple) else []
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

policy.q_values = q_vals
policy.q_func_vars = model.variables()
return policy.q_values, Categorical, [] # state-out
return policy.q_values, Categorical, state_out # state-out


def build_q_losses(policy, model, _, train_batch):
states_in = []
i = 0
while "state_in_{}".format(i) in train_batch:
states_in.append(train_batch["state_in_{}".format(i)])
i += 1
states_out = []
i = 0
while "state_out_{}".format(i) in train_batch:
states_out.append(train_batch["state_out_{}".format(i)])
i += 1
batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0]
if isinstance(train_batch[SampleBatch.CUR_OBS], tf.Tensor)
else len(train_batch[SampleBatch.CUR_OBS]))
seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size)
config = policy.config
# q network evaluation
q_t, q_logits_t, q_dist_t = compute_q_values(
q_t, q_logits_t, q_dist_t, q_state_t = compute_q_values(
policy,
policy.q_model,
train_batch[SampleBatch.CUR_OBS],
states_in,
seq_lens,
explore=False)

# target q network evaluation
q_tp1, q_logits_tp1, q_dist_tp1 = compute_q_values(
q_tp1, q_logits_tp1, q_dist_tp1, q_state_tp1 = compute_q_values(
policy,
policy.target_q_model,
train_batch[SampleBatch.NEXT_OBS],
states_out,
seq_lens,
explore=False)
policy.target_q_func_vars = policy.target_q_model.variables()

@@ -229,9 +254,11 @@ def build_q_losses(policy, model, _, train_batch):
# compute estimate of best possible value starting from state at t + 1
if config["double_q"]:
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net = compute_q_values(
q_dist_tp1_using_online_net, q_state_tp1_using_online_net = compute_q_values(
policy, policy.q_model,
train_batch[SampleBatch.NEXT_OBS],
states_out,
seq_lens,
explore=False)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best_one_hot_selection = tf.one_hot(q_tp1_best_using_online_net,
@@ -293,13 +320,15 @@ def setup_late_mixins(policy, obs_space, action_space, config):
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


def compute_q_values(policy, model, obs, explore):
def compute_q_values(policy, model, obs, states, seq_lens, explore):
config = policy.config

if type(states) is tuple:
states = list(states)
model_out, state = model({
SampleBatch.CUR_OBS: obs,
"is_training": policy._get_is_training_placeholder(),
}, [], None)
}, states, seq_lens)

if config["num_atoms"] > 1:
(action_scores, z, support_logits_per_action, logits,
@@ -332,10 +361,11 @@ def compute_q_values(policy, model, obs, explore):
else:
value = action_scores

return value, logits, dist
return value, logits, dist, state


def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
def _adjust_nstep(n_step, gamma, obs, states_in, actions, rewards, new_obs, dones,
states_out, seq_lens):
"""Rewrites the given trajectory fragments to encode n-step rewards.

reward[i] = (
Expand All @@ -358,15 +388,38 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
new_obs[i] = new_obs[i + j]
dones[i] = dones[i + j]
rewards[i] += gamma**j * rewards[i + j]
if len(states_out) >= (i + j + 1):
states_out[i] = states_out[i + j]
elif len(states_out) >= (i + 1):
states_out[i] = states_out[-1]


def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
state_in_out_and_seq_lens = {}
states_in = []
i = 0
key ="state_in_{}".format(i)
while key in batch:
states_in.append(batch[key])
state_in_out_and_seq_lens[key] = batch[key]
i += 1
key ="state_in_{}".format(i)
states_out = []
i = 0
key ="state_out_{}".format(i)
while key in batch:
states_out.append(batch[key])
state_in_out_and_seq_lens[key] = batch[key]
i += 1
key = "state_out_{}".format(i)
seq_lens = batch["seq_lens"] if "seq_lens" in batch else np.ones(len(batch[SampleBatch.CUR_OBS]))
state_in_out_and_seq_lens["seq_lens"] = seq_lens
# N-step Q adjustments
if policy.config["n_step"] > 1:
_adjust_nstep(policy.config["n_step"], policy.config["gamma"],
batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
batch[SampleBatch.CUR_OBS], states_in, batch[SampleBatch.ACTIONS],
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES])
batch[SampleBatch.DONES], states_out, seq_lens)

if PRIO_WEIGHTS not in batch:
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])
@@ -381,7 +434,7 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
td_errors = policy.compute_td_error(
batch[SampleBatch.CUR_OBS], actions,
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS], **state_in_out_and_seq_lens)
new_priorities = (
np.abs(td_errors) + policy.config["prioritized_replay_eps"])
batch.data[PRIO_WEIGHTS] = new_priorities
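A toy walk-through of the n-step state shift added to _adjust_nstep above; values are made up and the loop condition is simplified relative to the full function, which is partly elided in this diff:

# Toy data: 4-step trajectory, n_step=2, one recurrent state column.
import numpy as np

gamma, n_step = 0.99, 2
rewards = [1.0, 1.0, 1.0, 1.0]
dones = [False, False, False, True]
new_obs = [1, 2, 3, 4]
states_out = [np.full(2, t, dtype=np.float32) for t in range(4)]

for i in range(len(rewards)):
    for j in range(1, n_step):
        if i + j < len(rewards) and not dones[i]:
            new_obs[i] = new_obs[i + j]
            dones[i] = dones[i + j]
            rewards[i] += gamma ** j * rewards[i + j]
            # New in this PR: shift the post-step recurrent state the same way,
            # guarded against state lists shorter than the trajectory.
            if len(states_out) >= (i + j + 1):
                states_out[i] = states_out[i + j]
            elif len(states_out) >= (i + 1):
                states_out[i] = states_out[-1]

# rewards -> [1.99, 1.99, 1.99, 1.0]; states_out[0] now holds the old states_out[1].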