[RLlib] Created action_dist_v2 for RLModule examples, RLModule PR 2/N (#29600)

* created action_dist_v2 for RLModule examples

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
1 parent 241a02e, commit 3562cb4
Showing 4 changed files with 617 additions and 0 deletions.
@@ -0,0 +1,124 @@
"""This is the next version of action distribution base class.""" | ||
from typing import Tuple | ||
import gym | ||
import abc | ||
|
||
from ray.rllib.utils.annotations import ExperimentalAPI | ||
from ray.rllib.utils.typing import TensorType, Union, ModelConfigDict | ||


@ExperimentalAPI
class Distribution(abc.ABC):
    """The base class for a distribution over a random variable.

    Examples:
        >>> model = ...  # a model that outputs a vector of logits
        >>> action_logits = model.forward(obs)
        >>> action_dist = Distribution(action_logits)
        >>> action = action_dist.sample()
        >>> logp = action_dist.logp(action)
        >>> kl = action_dist.kl(action_dist2)
        >>> entropy = action_dist.entropy()
    """

    @abc.abstractmethod
    def sample(
        self,
        *,
        sample_shape: Optional[Tuple[int, ...]] = None,
        return_logp: bool = False,
        **kwargs
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a sample from the distribution.

        Args:
            sample_shape: The shape of the sample to draw.
            return_logp: Whether to return the logp of the sampled values.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled values. If `return_logp` is True, returns a tuple of
            the sampled values and their log-probabilities.
        """

    @abc.abstractmethod
    def rsample(
        self,
        *,
        sample_shape: Optional[Tuple[int, ...]] = None,
        return_logp: bool = False,
        **kwargs
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a re-parameterized sample from the action distribution.

        If this method is implemented, we can take gradients of samples w.r.t.
        the distribution parameters.

        Args:
            sample_shape: The shape of the sample to draw.
            return_logp: Whether to return the logp of the sampled values.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled values. If `return_logp` is True, returns a tuple of
            the sampled values and their log-probabilities.
        """

    @abc.abstractmethod
    def logp(self, value: TensorType, **kwargs) -> TensorType:
        """The log-likelihood of the distribution computed at `value`.

        Args:
            value: The value to compute the log-likelihood at.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The log-likelihood of the value.
        """

    @abc.abstractmethod
    def kl(self, other: "Distribution", **kwargs) -> TensorType:
        """The KL-divergence between two distributions.

        Args:
            other: The other distribution.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The KL-divergence between the two distributions.
        """

    @abc.abstractmethod
    def entropy(self, **kwargs) -> TensorType:
        """The entropy of the distribution.

        Args:
            **kwargs: Forward compatibility placeholder.

        Returns:
            The entropy of the distribution.
        """

    @staticmethod
    @abc.abstractmethod
    def required_model_output_shape(
        space: gym.Space, model_config: ModelConfigDict
    ) -> Tuple[int, ...]:
        """Returns the required shape of an input parameter tensor for a
        particular space and an optional dict of distribution-specific
        options.

        Let's have this method here just as a reminder to the next developer
        that this was part of the old distribution classes that we may or may
        not keep, depending on how the catalog gets written.

        Args:
            space: The space this distribution will be used for, whose shape
                attributes will be used to determine the required shape of the
                input parameter tensor.
            model_config: Model's config dict (as defined in catalog.py).

        Returns:
            Size of the required input vector (minus leading batch dimension).
        """
@@ -0,0 +1,234 @@
import math
import unittest
from copy import copy

import numpy as np

from ray.rllib.models.torch.torch_distributions import (
    TorchCategorical,
    TorchDeterministic,
    TorchDiagGaussian,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import LARGE_INTEGER, SMALL_NUMBER, softmax
from ray.rllib.utils.test_utils import check

torch, _ = try_import_torch()

def check_stability(dist_class, *, sample_input=None, constraints=None):
    """Checks that `dist_class` stays numerically stable on extreme inputs.

    Feeds randomly chosen extreme parameter values (zeros, huge and tiny
    magnitudes of both signs) into the distribution and asserts that samples
    and their log-probabilities are neither NaN nor infinite.
    """
    max_tries = 100
    extreme_values = [
        0.0,
        float(LARGE_INTEGER),
        -float(LARGE_INTEGER),
        1.1e-34,
        1.1e34,
        -1.1e-34,
        -1.1e34,
        SMALL_NUMBER,
        -SMALL_NUMBER,
    ]

    input_kwargs = copy(sample_input)
    for key, array in input_kwargs.items():
        arr_sampled = np.random.choice(extreme_values, replace=True, size=array.shape)
        input_kwargs[key] = torch.from_numpy(arr_sampled).float()

        if constraints:
            constraint = constraints.get(key, None)
            if constraint:
                if constraint == "positive_not_inf":
                    # Map to a strictly positive, finite range via softplus,
                    # clamped from above at LARGE_INTEGER.
                    input_kwargs[key] = torch.minimum(
                        SMALL_NUMBER + torch.log(1 + torch.exp(input_kwargs[key])),
                        torch.tensor([LARGE_INTEGER]),
                    )
                elif constraint == "probability":
                    input_kwargs[key] = torch.softmax(input_kwargs[key], dim=-1)

    dist = dist_class(**input_kwargs)
    for _ in range(max_tries):
        sample = dist.sample()

        assert not torch.isnan(sample).any()
        assert torch.all(torch.isfinite(sample))

        logp = dist.logp(sample)
        assert not torch.isnan(logp).any()
        assert torch.all(torch.isfinite(logp))


class TestDistributions(unittest.TestCase):
    """Tests the new torch distribution classes."""

    @classmethod
    def setUpClass(cls) -> None:
        # Set seeds for deterministic tests (make sure we don't fail
        # because of "bad" sampling).
        np.random.seed(42)
        torch.manual_seed(42)

    def test_categorical(self):
        batch_size = 10000
        num_categories = 4
        sample_shape = 2

        # Create categorical distribution with n categories.
        logits = np.random.randn(batch_size, num_categories)
        probs = torch.from_numpy(softmax(logits)).float()
        logits = torch.from_numpy(logits).float()

        # check stability against skewed inputs
        check_stability(TorchCategorical, sample_input={"logits": logits})
        check_stability(
            TorchCategorical,
            sample_input={"probs": logits},
            constraints={"probs": "probability"},
        )

        dist_with_logits = TorchCategorical(logits=logits)
        dist_with_probs = TorchCategorical(probs=probs)

        samples = dist_with_logits.sample(sample_shape=(sample_shape,))

        # check shape of samples
        self.assertEqual(
            samples.shape,
            (
                sample_shape,
                batch_size,
            ),
        )
        self.assertEqual(samples.dtype, torch.int64)
        # check that none of the samples are nan
        self.assertFalse(torch.isnan(samples).any())
        # check that all samples are in the range of the number of categories
        self.assertTrue((samples >= 0).all())
        self.assertTrue((samples < num_categories).all())

        # resample without a sample_shape to drop the leading sample dimension
        samples = dist_with_logits.sample()
        # check that the two distributions are the same
        check(dist_with_logits.logp(samples), dist_with_probs.logp(samples))

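        # Reference: for a categorical, log p(a) is just the log-probability
        # table indexed at the sampled class, i.e. log(probs)[i, a_i].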
        # check logp values
        expected = probs.log().gather(dim=-1, index=samples.view(-1, 1)).view(-1)
        check(dist_with_logits.logp(samples), expected)

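        # Reference: H(p) = -sum_i p_i * log(p_i).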
        # check entropy
        expected = -(probs * probs.log()).sum(dim=-1)
        check(dist_with_logits.entropy(), expected)

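        # Reference: KL(p || q) = sum_i p_i * log(p_i / q_i).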
        # check kl
        probs2 = softmax(np.random.randn(batch_size, num_categories))
        probs2 = torch.from_numpy(probs2).float()
        dist2 = TorchCategorical(probs=probs2)
        expected = (probs * (probs / probs2).log()).sum(dim=-1)
        check(dist_with_probs.kl(dist2), expected)

        # check rsample
        self.assertRaises(NotImplementedError, dist_with_logits.rsample)

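        # Standard temperature scaling divides the logits by T before the
        # softmax; as T -> 0 the distribution collapses onto the argmax.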
        # test temperature
        dist_with_logits = TorchCategorical(logits=logits, temperature=1e-20)
        samples = dist_with_logits.sample()
        # expected is argmax of logits
        expected = logits.argmax(dim=-1)
        check(samples, expected)

    def test_diag_gaussian(self):
        batch_size = 128
        ndim = 4
        sample_shape = 100000

        loc = np.random.randn(batch_size, ndim)
        scale = np.exp(np.random.randn(batch_size, ndim))

        loc_tens = torch.from_numpy(loc).float()
        scale_tens = torch.from_numpy(scale).float()

        dist = TorchDiagGaussian(loc=loc_tens, scale=scale_tens)
        sample = dist.sample(sample_shape=(sample_shape,))

        # check shape of samples
        self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
        self.assertEqual(sample.dtype, torch.float32)
        # check that none of the samples are nan
        self.assertFalse(torch.isnan(sample).any())

        # check that mean and std are approximately correct
        check(sample.mean(0), loc, decimals=1)
        check(sample.std(0), scale, decimals=1)

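        # Reference: log N(x | mu, sigma) =
        #     -0.5 * sum_d ((x_d - mu_d) / sigma_d)^2
        #     - 0.5 * D * log(2 * pi) - sum_d log(sigma_d).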
        # check logp values
        expected = (
            -0.5 * ((sample - loc_tens) / scale_tens).pow(2).sum(-1)
            - 0.5 * ndim * math.log(2 * math.pi)
            - scale_tens.log().sum(-1)
        )
        check(dist.logp(sample), expected)

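        # Reference: H = 0.5 * D * (1 + log(2 * pi)) + sum_d log(sigma_d).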
        # check entropy
        expected = 0.5 * ndim * (1 + math.log(2 * math.pi)) + scale_tens.log().sum(-1)
        check(dist.entropy(), expected)

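        # Reference: KL(p || q) between diagonal Gaussians, per dimension:
        #     log(sigma_q / sigma_p)
        #     + (sigma_p^2 + (mu_p - mu_q)^2) / (2 * sigma_q^2) - 0.5,
        # summed over dimensions.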
        # check kl
        loc2 = torch.from_numpy(np.random.randn(batch_size, ndim)).float()
        scale2 = torch.from_numpy(np.exp(np.random.randn(batch_size, ndim))).float()
        dist2 = TorchDiagGaussian(loc=loc2, scale=scale2)
        expected = (
            scale2.log()
            - scale_tens.log()
            + (scale_tens.pow(2) + (loc_tens - loc2).pow(2)) / (2 * scale2.pow(2))
            - 0.5
        ).sum(-1)
        check(dist.kl(dist2), expected, decimals=4)

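        # rsample() draws via the re-parameterization trick
        # (x = loc + scale * eps, eps ~ N(0, 1)), so gradients flow back to
        # loc and scale; sample() is detached from the graph and cannot be
        # backpropagated through.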
        # check rsample
        loc_tens.requires_grad = True
        scale_tens.requires_grad = True
        dist = TorchDiagGaussian(loc=2 * loc_tens, scale=2 * scale_tens)
        sample1 = dist.rsample()
        sample2 = dist.sample()

        self.assertRaises(
            RuntimeError, lambda: sample2.mean().backward(retain_graph=True)
        )
        sample1.mean().backward(retain_graph=True)

        # check stability against skewed inputs
        check_stability(
            TorchDiagGaussian,
            sample_input={"loc": loc_tens, "scale": scale_tens},
            constraints={"scale": "positive_not_inf"},
        )

    def test_deterministic(self):
        batch_size = 128
        ndim = 4
        sample_shape = 100000

        loc = np.random.randn(batch_size, ndim)

        loc_tens = torch.from_numpy(loc).float()

        dist = TorchDeterministic(loc=loc_tens)
        sample = dist.sample(sample_shape=(sample_shape,))
        sample2 = dist.sample(sample_shape=(sample_shape,))
        # a deterministic distribution must return identical samples every time
        check(sample, sample2)

        # check shape of samples
        self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
        self.assertEqual(sample.dtype, torch.float32)
        # check that none of the samples are nan
        self.assertFalse(torch.isnan(sample).any())


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))