[RLlib] Added expectation advantage_type to CRR #26142

Merged: 3 commits, Jun 28, 2022
14 changes: 14 additions & 0 deletions rllib/BUILD
@@ -281,6 +281,20 @@ py_test(
args = ["--yaml-dir=tuned_examples/crr", '--framework=torch']
)

py_test(
[Review comment from a Contributor]: Nice.

name = "learning_tests_cartpole_crr_expectation",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
# Include an offline json data file as well.
data = [
"tuned_examples/crr/cartpole-v0-crr_expectation.yaml",
"tests/data/cartpole/large.json",
],
args = ["--yaml-dir=tuned_examples/crr", '--framework=torch']
)
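For a quick local run outside Bazel, a rough equivalent of this target (an assumption for illustration, not part of the PR; run from the rllib/ directory with the offline dataset in place) is:

import subprocess
import sys

# Invoke the same regression script with the same arguments the BUILD target
# passes; under Bazel the `data` entries above are what gets sandboxed in.
subprocess.run(
    [
        sys.executable,
        "tests/run_regression_tests.py",
        "--yaml-dir=tuned_examples/crr",
        "--framework=torch",
    ],
    check=True,
)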

# DDPG
# py_test(
# name = "learning_tests_pendulum_ddpg",
21 changes: 20 additions & 1 deletion rllib/algorithms/crr/crr.py
@@ -88,7 +88,26 @@ def training(
weight_type: weight type to use `bin` | `exp`.
temperature: the exponent temperature used in exp weight type.
max_weight: the max weight limit for exp weight type.
advantage_type: The way we reduce q values to v_t values `max` | `mean`.
advantage_type: The way we reduce Q-values to v_t values:
`max` | `mean` | `expectation`. `max` and `mean` work for both
discrete and continuous action spaces, while `expectation` only
works for discrete action spaces.
`max`: Uses the max over sampled actions to estimate the value.
.. math::
    A(s_t, a_t) = Q(s_t, a_t) - \max_{a^j} Q(s_t, a^j)
where :math:`a^j` denotes `n_action_sample` actions sampled from
the policy :math:`\pi(a | s_t)`.
`mean`: Uses the mean over sampled actions to estimate the value.
.. math::
    A(s_t, a_t) = Q(s_t, a_t) - \frac{1}{m}\sum_{j=1}^{m} Q(s_t, a^j)
where :math:`a^j` denotes `n_action_sample` actions sampled from
the policy :math:`\pi(a | s_t)`.
`expectation`: Uses the categorical policy distribution to compute
the expectation of the Q-values directly as the value estimate.
.. math::
    A(s_t, a_t) = Q(s_t, a_t) - E_{a^j \sim \pi(a|s_t)}[Q(s_t, a^j)]
n_action_sample: the number of actions to sample for v_t estimation.
twin_q: if True, uses pessimistic q estimation.
target_update_grad_intervals: The frequency at which we update the
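The three estimators in the training() docstring above can be illustrated with a small, self-contained PyTorch sketch for a discrete action space (toy tensor shapes and a random policy are assumptions for illustration; the PR's actual computation lives in CRRTorchPolicy._compute_adv_and_logps):

import torch

batch_size, num_actions, n_action_sample = 4, 2, 16

q_all = torch.randn(batch_size, num_actions)  # Q(s_t, .) over all actions
probs = torch.softmax(torch.randn(batch_size, num_actions), dim=-1)  # pi(a | s_t)
a_t = torch.randint(num_actions, (batch_size,))  # dataset actions
q_t = q_all.gather(-1, a_t.unsqueeze(-1))  # Q(s_t, a_t), shape [B, 1]

# `max` / `mean`: sample a^j ~ pi(a | s_t) and reduce over the samples.
dist = torch.distributions.Categorical(probs=probs)
a_j = dist.sample((n_action_sample,))  # [n_action_sample, B]
q_sampled = q_all.gather(-1, a_j.T).T.unsqueeze(-1)  # [n_action_sample, B, 1]
v_max, _ = q_sampled.max(dim=0)  # max_j Q(s_t, a^j)
v_mean = q_sampled.mean(dim=0)  # (1/m) sum_j Q(s_t, a^j)

# `expectation`: no sampling; use the categorical probabilities directly.
v_exp = (q_all * probs).sum(-1, keepdim=True)  # E_{a ~ pi}[Q(s_t, a)]

adv_max, adv_mean, adv_exp = q_t - v_max, q_t - v_mean, q_t - v_exp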
82 changes: 53 additions & 29 deletions rllib/algorithms/crr/torch/crr_torch_policy.py
@@ -202,26 +202,24 @@ def _compute_adv_and_logps(
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch,
) -> None:
# uses mean|max to compute estimate of advantages
# uses mean|max|expectation to compute estimate of advantages
# continuous/discrete action spaces:
# for max:
# A(s_t, a_t) = Q(s_t, a_t) - max_{a^j} Q(s_t, a^j)
# where a^j is m times sampled from the policy p(a | s_t)
# for mean:
# A(s_t, a_t) = Q(s_t, a_t) - avg( Q(s_t, a^j) )
# where a^j is m times sampled from the policy p(a | s_t)
# questions: Do we use pessimistic q approximate or the normal one?
# discrete action space and adv_type=expectation:
# A(s_t, a_t) = Q(s_t, a_t) - sum_j[Q(s_t, a^j) * pi(a^j)]
advantage_type = self.config["advantage_type"]
n_action_sample = self.config["n_action_sample"]
batch_size = len(train_batch)
out_t, _ = model(train_batch)

# construct pi(s_t) for sampling actions
# construct pi(s_t) and Q(s_t, a_t) for computing the advantage
pi_s_t = dist_class(model.get_policy_output(out_t), model)
policy_actions = pi_s_t.dist.sample((n_action_sample,)) # samples

if self._is_action_discrete:
flat_actions = policy_actions.reshape(-1)
else:
flat_actions = policy_actions.reshape(-1, *self.action_space.shape)
q_t = self._get_q_value(model, out_t, train_batch[SampleBatch.ACTIONS])

# compute the logp of the actions in the dataset (for computing actor's loss)
action_logp = pi_s_t.dist.log_prob(train_batch[SampleBatch.ACTIONS])
@@ -231,30 +229,56 @@ def _compute_adv_and_logps(
action_logp.unsqueeze_(-1)
train_batch[SampleBatch.ACTION_LOGP] = action_logp

reshaped_s_t = train_batch[SampleBatch.OBS].view(
1, batch_size, *self.observation_space.shape
)
reshaped_s_t = reshaped_s_t.expand(
n_action_sample, batch_size, *self.observation_space.shape
)
flat_s_t = reshaped_s_t.reshape(-1, *self.observation_space.shape)
if advantage_type == "expectation":
assert (
self._is_action_discrete
), "Action space should be discrete when advantage_type = expectation."
assert hasattr(
self.model, "q_model"
), "CRR's ModelV2 should have q_model neural network in discrete \
action spaces"
assert isinstance(
pi_s_t.dist, torch.distributions.Categorical
), "The output of the policy should be a torch Categorical \
distribution."

q_vals = self.model.q_model(out_t)
if hasattr(self.model, "twin_q_model"):
q_twins = self.model.twin_q_model(out_t)
q_vals = torch.minimum(q_vals, q_twins)

probs = pi_s_t.dist.probs
v_t = (q_vals * probs).sum(-1, keepdims=True)
else:
policy_actions = pi_s_t.dist.sample((n_action_sample,)) # samples

input_v_t = SampleBatch(
**{SampleBatch.OBS: flat_s_t, SampleBatch.ACTIONS: flat_actions}
)
out_v_t, _ = model(input_v_t)
if self._is_action_discrete:
flat_actions = policy_actions.reshape(-1)
else:
flat_actions = policy_actions.reshape(-1, *self.action_space.shape)

flat_q_st_pi = self._get_q_value(model, out_v_t, flat_actions)
reshaped_q_st_pi = flat_q_st_pi.reshape(-1, batch_size, 1)
reshaped_s_t = train_batch[SampleBatch.OBS].view(
1, batch_size, *self.observation_space.shape
)
reshaped_s_t = reshaped_s_t.expand(
n_action_sample, batch_size, *self.observation_space.shape
)
flat_s_t = reshaped_s_t.reshape(-1, *self.observation_space.shape)

if advantage_type == "mean":
v_t = reshaped_q_st_pi.mean(dim=0)
elif advantage_type == "max":
v_t, _ = reshaped_q_st_pi.max(dim=0)
else:
raise ValueError(f"Invalid advantage type: {advantage_type}.")
input_v_t = SampleBatch(
**{SampleBatch.OBS: flat_s_t, SampleBatch.ACTIONS: flat_actions}
)
out_v_t, _ = model(input_v_t)

q_t = self._get_q_value(model, out_t, train_batch[SampleBatch.ACTIONS])
flat_q_st_pi = self._get_q_value(model, out_v_t, flat_actions)
reshaped_q_st_pi = flat_q_st_pi.reshape(-1, batch_size, 1)

if advantage_type == "mean":
v_t = reshaped_q_st_pi.mean(dim=0)
elif advantage_type == "max":
v_t, _ = reshaped_q_st_pi.max(dim=0)
else:
raise ValueError(f"Invalid advantage type: {advantage_type}.")

adv_t = q_t - v_t
train_batch["advantages"] = adv_t
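A small numerical check of the design choice (illustrative only, not part of the diff): for a discrete policy, the sampled `mean` estimate of the state value converges to the closed-form value the new `expectation` branch reads off `pi_s_t.dist.probs`, so the new path removes Monte-Carlo noise without any action sampling:

import torch

torch.manual_seed(0)
q_vals = torch.tensor([[1.0, 2.0, 4.0]])  # Q(s_t, .) for a single state
probs = torch.tensor([[0.2, 0.5, 0.3]])  # pi(. | s_t)
dist = torch.distributions.Categorical(probs=probs)

v_exact = (q_vals * probs).sum(-1).item()  # 0.2*1 + 0.5*2 + 0.3*4 = 2.4

for m in (4, 64, 4096):
    a_j = dist.sample((m,))  # [m, 1] sampled actions
    v_mc = q_vals.gather(-1, a_j.T).mean().item()  # Monte-Carlo mean estimate
    print(f"n_action_sample={m:5d}  mean-estimate={v_mc:.3f}  expectation={v_exact:.3f}")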
3 changes: 2 additions & 1 deletion rllib/tuned_examples/crr/cartpole-v0-crr.yaml
@@ -37,7 +37,8 @@ cartpole_crr:
learning_starts: 0
capacity: 100000
# specific to CRR
temperature: 1.0
weight_type: bin
advantage_type: max
advantage_type: mean
max_weight: 20.0
n_action_sample: 4
44 changes: 44 additions & 0 deletions rllib/tuned_examples/crr/cartpole-v0-crr_expectation.yaml
@@ -0,0 +1,44 @@
cartpole_crr:
    env: 'CartPole-v0'
    run: CRR
    stop:
        evaluation/episode_reward_mean: 200
        training_iteration: 100
    config:
        input:
            - 'tests/data/cartpole/large.json'
        framework: torch
        gamma: 0.99
        train_batch_size: 2048
        critic_hidden_activation: 'tanh'
        critic_hiddens: [128, 128, 128]
        critic_lr: 0.0003
        actor_hidden_activation: 'tanh'
        actor_hiddens: [128, 128, 128]
        actor_lr: 0.0003
        actions_in_input_normalized: True
        clip_actions: True
        # Q function update setting
        twin_q: True
        target_update_grad_intervals: 1
        tau: 0.0005
        # evaluation
        evaluation_config:
            explore: False
            input: sampler
        evaluation_duration: 10
        evaluation_duration_unit: episodes
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True
        # replay buffer
        replay_buffer_config:
            type: ray.rllib.utils.replay_buffers.MultiAgentReplayBuffer
            learning_starts: 0
            capacity: 100000
        # specific to CRR
        temperature: 1.0
        weight_type: bin
        advantage_type: expectation
        max_weight: 20.0
        n_action_sample: 4
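For reference, an approximate Python-API equivalent of the YAML above (a sketch only; it assumes the Ray 2.x CRRConfig builder, omits the actor/critic model settings, and exact method signatures may differ across Ray versions):

from ray.rllib.algorithms.crr import CRRConfig

config = (
    CRRConfig()
    .environment("CartPole-v0")
    .framework("torch")
    .offline_data(input_=["tests/data/cartpole/large.json"])
    .training(
        gamma=0.99,
        train_batch_size=2048,
        twin_q=True,
        target_update_grad_intervals=1,
        # CRR-specific settings, including the new advantage_type.
        temperature=1.0,
        weight_type="bin",
        advantage_type="expectation",
        max_weight=20.0,
        n_action_sample=4,
    )
    .evaluation(
        evaluation_interval=1,
        evaluation_duration=10,
        evaluation_num_workers=1,
        evaluation_parallel_to_training=True,
        evaluation_config={"explore": False, "input": "sampler"},
    )
)
algo = config.build()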
1 change: 1 addition & 0 deletions rllib/tuned_examples/crr/pendulum-v1-crr.yaml
@@ -37,6 +37,7 @@ pendulum_crr:
learning_starts: 0
capacity: 100000
# specific to CRR
temperature: 1.0
weight_type: exp
advantage_type: max
max_weight: 20.0