[RLlib]: Fix FQE Policy call #26671

Merged
merged 9 commits on Jul 19, 2022
40 changes: 26 additions & 14 deletions rllib/offline/estimators/fqe_torch_model.py
@@ -110,8 +110,9 @@ def train(self, batch: SampleBatch) -> TensorType:
             A list of losses for each training iteration
         """
         losses = []
-        if self.minibatch_size is None:
-            minibatch_size = batch.count
+        minibatch_size = self.minibatch_size or batch.count
+        # Copy batch for shuffling
+        batch = batch.copy(shallow=True)
         for _ in range(self.n_iters):
             minibatch_losses = []
             batch.shuffle()
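The hunk above defaults the minibatch size to the full batch and trains on a shallow copy so that shuffling never reorders the caller's batch. A minimal standalone sketch of that intent, assuming SampleBatch.shuffle() rebinds permuted arrays on the copy rather than permuting the shared buffers in place (the property the new test_fqe_model check relies on); the batch contents below are made up:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Toy batch; the column values are made up.
batch = SampleBatch({"obs": np.arange(4), "actions": np.array([0, 1, 1, 0])})

# Same default as the diff: fall back to the full batch size.
minibatch_size = None or batch.count

# Shuffle a shallow copy; the caller's batch keeps its original ordering
# (assumed behavior, verified indirectly by the new test below).
local = batch.copy(shallow=True)
local.shuffle()
assert np.array_equal(batch["obs"], np.arange(4))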
@@ -209,18 +210,29 @@ def _compute_action_probs(self, obs: TensorType) -> TensorType:
         input_dict = {SampleBatch.OBS: obs}
         seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
         state_batches = []
-        if self.policy.action_distribution_fn and is_overridden(
-            self.policy.action_distribution_fn
-        ):
-            dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
-                self.policy,
-                self.policy.model,
-                input_dict=input_dict,
-                state_batches=state_batches,
-                seq_lens=seq_lens,
-                explore=False,
-                is_training=False,
-            )
+        if is_overridden(self.policy.action_distribution_fn):
+            try:
+                # TorchPolicyV2 function signature
+                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
+                    self.policy.model,
+                    obs_batch=input_dict,
+                    state_batches=state_batches,
+                    seq_lens=seq_lens,
+                    explore=False,
+                    is_training=False,
+                )
+            except TypeError:
+                # TorchPolicyV1 function signature for compatibility with DQN
+                # TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
+                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
+                    self.policy,
+                    self.policy.model,
+                    input_dict=input_dict,
+                    state_batches=state_batches,
+                    seq_lens=seq_lens,
+                    explore=False,
+                    is_training=False,
+                )
         else:
             dist_class = self.policy.dist_class
             dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
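For readers outside RLlib, here is a self-contained sketch of the dispatch pattern used above: try the newer keyword-based signature first and fall back on TypeError to the older positional one. The two callback functions are hypothetical stand-ins, not RLlib APIs.

def v2_style_fn(model, *, obs_batch, explore, is_training):
    # Newer (V2-style) signature: the model comes first, observations as `obs_batch`.
    return "dist_inputs", "dist_class", None

def v1_style_fn(policy, model, *, input_dict, explore, is_training):
    # Older (V1-style) signature: the policy itself comes first, observations as `input_dict`.
    return "dist_inputs", "dist_class", None

def call_action_dist_fn(fn, policy, model, obs):
    try:
        # Prefer the newer keyword-based call ...
        return fn(model, obs_batch=obs, explore=False, is_training=False)
    except TypeError:
        # ... and fall back to the older positional call when the signature differs.
        return fn(policy, model, input_dict=obs, explore=False, is_training=False)

print(call_action_dist_fn(v2_style_fn, policy=None, model=None, obs={}))
print(call_action_dist_fn(v1_style_fn, policy=None, model=None, obs={}))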
46 changes: 46 additions & 0 deletions rllib/offline/estimators/tests/test_ope.py
@@ -10,10 +10,14 @@
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.policy.sample_batch import concat_samples
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.numpy import convert_to_numpy
from pathlib import Path
import os
import copy
import numpy as np
import gym
import torch


class TestOPE(unittest.TestCase):
@@ -162,6 +166,48 @@ def test_ope_in_algo(self):
        print(*list(std_est.items()), sep="\n")
        print("\n\n\n")

    def test_fqe_model(self):
        # Test FQETorchModel for:
        # (1) Check that it does not modify the underlying batch during training
        # (2) Check that the stopping criteria from FQE are working correctly
        # (3) Check that using fqe._compute_action_probs equals brute force
        # iterating over all actions with policy.compute_log_likelihoods
        fqe = FQETorchModel(
            policy=self.algo.get_policy(),
            gamma=self.gamma,
            **self.q_model_config,
        )
        tmp_batch = copy.deepcopy(self.batch)
        losses = fqe.train(self.batch)

        # Make sure FQETorchModel.train() does not modify self.batch
        check(tmp_batch, self.batch)

        # Make sure FQE stopping criteria are respected
        assert (
            len(losses) == fqe.n_iters or losses[-1] < fqe.delta
        ), (
            f"FQE.train() terminated early in {len(losses)} steps with final loss "
            f"{losses[-1]} for n_iters: {fqe.n_iters} and delta: {fqe.delta}"
        )

        # Test fqe._compute_action_probs against "brute force" method
        # of computing log_prob for each possible action individually
        # using policy.compute_log_likelihoods
        obs = torch.tensor(self.batch["obs"], device=fqe.device)
        action_probs = fqe._compute_action_probs(obs)
        action_probs = convert_to_numpy(action_probs)

        tmp_probs = []
        for act in range(fqe.policy.action_space.n):
            tmp_actions = np.zeros_like(self.batch["actions"]) + act
            log_probs = fqe.policy.compute_log_likelihoods(
                actions=tmp_actions,
                obs_batch=self.batch["obs"],
            )
            tmp_probs.append(torch.exp(log_probs))
        tmp_probs = torch.stack(tmp_probs).transpose(0, 1)
        tmp_probs = convert_to_numpy(tmp_probs)
        check(action_probs, tmp_probs, decimals=3)

    def test_multiple_inputs(self):
        # TODO (Rohan138): Test with multiple input files
        pass
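As a standalone illustration of the equivalence test_fqe_model asserts: for a discrete policy, per-action probabilities computed in one pass should match exponentiated log-likelihoods computed one action at a time. The toy categorical logits below are made up and no RLlib objects are involved.

import torch

logits = torch.randn(5, 3)                    # 5 observations, 3 discrete actions
action_probs = torch.softmax(logits, dim=-1)  # "one pass" per-action probabilities

# Brute force: one action at a time via log-probabilities, then stack and exponentiate.
log_probs = torch.log_softmax(logits, dim=-1)
brute_force = torch.stack(
    [torch.exp(log_probs[:, a]) for a in range(3)]
).transpose(0, 1)

assert torch.allclose(action_probs, brute_force, atol=1e-6)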