[RLlib] Created action_dist_v2 for RLModule examples, RLModule PR 2/N #29600
@@ -0,0 +1,119 @@
from typing import Tuple
import abc

import gym

from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.typing import TensorType, Union, ModelConfigDict


@ExperimentalAPI
class ActionDistributionV2(abc.ABC):
    """The policy action distribution of an agent.

    Examples:
        >>> model = ...  # a model that outputs a vector of logits
        >>> action_logits = model.forward(obs)
        >>> action_dist = ActionDistributionV2(action_logits)
        >>> action = action_dist.sample()
        >>> logp = action_dist.logp(action)
        >>> kl = action_dist.kl(action_dist2)
        >>> entropy = action_dist.entropy()
    """
    @abc.abstractmethod
    def sample(
        self,
        *,
        sample_shape: Tuple[int, ...] = None,
        return_logp: bool = False,
        **kwargs,
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a sample from the action distribution.

        Args:
            sample_shape: The shape of the sample to draw.
            return_logp: Whether to return the logp of the sampled action.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled action. If return_logp is True, returns a tuple of the
            sampled action and its logp.
        """
    @abc.abstractmethod
    def rsample(
        self,
        *,
        sample_shape: Tuple[int, ...] = None,
        return_logp: bool = False,
        **kwargs,
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a re-parameterized sample from the action distribution.

        In contrast to `sample()`, the returned sample is expressed as a
        deterministic, differentiable function of the distribution parameters
        and an independent noise source, so gradients can flow through it
        (the re-parameterization trick).

        Args:
            sample_shape: The shape of the sample to draw.
            return_logp: Whether to return the logp of the sampled action.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled action. If return_logp is True, returns a tuple of the
            sampled action and its logp.
        """
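The difference between `sample()` and `rsample()` can be illustrated with a small numpy sketch (hypothetical, not part of this PR): a re-parameterized Gaussian sample is written as `loc + scale * eps` with parameter-free noise `eps`, which is exactly the form an autodiff framework can differentiate with respect to `loc` and `scale`.

```python
import numpy as np

rng = np.random.default_rng(0)
loc, scale = 1.5, 0.5

# Re-parameterization trick: the sample is a deterministic function of the
# parameters (loc, scale) and parameter-free noise eps ~ N(0, 1). A framework
# such as torch could backpropagate through loc and scale here; a plain
# dist.sample() call would cut the gradient path instead.
eps = rng.standard_normal(100_000)
samples = loc + scale * eps

# The samples still follow N(loc, scale^2).
print(samples.mean(), samples.std())
```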
    @abc.abstractmethod
    def logp(self, action: TensorType, **kwargs) -> TensorType:
        """The log-likelihood of the action distribution.

        Args:
            action: The action to compute the log-likelihood for.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The log-likelihood of the action.
        """
    @abc.abstractmethod
    def kl(self, other: "ActionDistributionV2", **kwargs) -> TensorType:
        """The KL-divergence between two action distributions.

        Args:
            other: The other action distribution.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The KL-divergence between the two action distributions.
        """
    @abc.abstractmethod
    def entropy(self, **kwargs) -> TensorType:
        """The entropy of the action distribution.

        Args:
            **kwargs: Forward compatibility placeholder.

        Returns:
            The entropy of the action distribution.
        """
    @staticmethod
    @abc.abstractmethod
    def required_model_output_shape(
        space: gym.Space, model_config: ModelConfigDict
    ) -> Tuple[int, ...]:
        """Returns the required shape of an input parameter tensor for a
        particular (action) space and an optional dict of distribution-specific
        options.

        Note: This method is carried over from the old ActionDistribution API
        and may or may not be kept, depending on how the catalog gets written.

        Args:
            space: The space this distribution will be used for, whose shape
                attributes will be used to determine the required shape of the
                input parameter tensor.
            model_config: Model's config dict (as defined in catalog.py).

        Returns:
            Size of the required input vector (minus leading batch dimension).
        """
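To make the abstract interface more concrete, here is a hypothetical numpy-based categorical distribution following the same method names (`CategoricalSketch` is an illustration only; the actual RLlib implementations are torch-based and live in `torch_action_dist_v2`, tested below):

```python
import numpy as np

class CategoricalSketch:
    """Hypothetical minimal categorical distribution (numpy, illustration only)."""

    def __init__(self, logits: np.ndarray):
        # Stable softmax: shift by the max logit before exponentiating.
        z = np.exp(logits - logits.max(axis=-1, keepdims=True))
        self.probs = z / z.sum(axis=-1, keepdims=True)

    def sample(self) -> np.ndarray:
        # Inverse-CDF sampling, one draw per batch row.
        cdf = self.probs.cumsum(axis=-1)
        u = np.random.rand(*self.probs.shape[:-1], 1)
        return (u > cdf).sum(axis=-1)

    def logp(self, action: np.ndarray) -> np.ndarray:
        # Pick the probability of the chosen category per batch row.
        p = np.take_along_axis(self.probs, action[..., None], axis=-1)
        return np.log(p)[..., 0]

    def entropy(self) -> np.ndarray:
        return -(self.probs * np.log(self.probs)).sum(axis=-1)
```

For uniform logits over 4 categories, `entropy()` returns `log(4)` per row and every sample falls in `[0, 4)`.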
@@ -0,0 +1,235 @@
from copy import copy
import unittest

import numpy as np

from ray.rllib.models.torch.torch_action_dist_v2 import (
    TorchCategorical,
    TorchDiagGaussian,
    TorchDeterministic,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import (
    softmax,
    SMALL_NUMBER,
    LARGE_INTEGER,
)
from ray.rllib.utils.test_utils import check

torch, _ = try_import_torch()
def check_stability(dist_class, *, sample_input=None, constraints=None):
    """Stress-tests a distribution with extreme input magnitudes.

    Replaces each input tensor with values drawn from a set of extreme
    magnitudes (optionally constrained, e.g. to positive-finite or
    probability-simplex values) and asserts that samples and their logp
    stay finite and nan-free.
    """
    max_tries = 100
    extreme_values = [
        0.0,
        float(LARGE_INTEGER),
        -float(LARGE_INTEGER),
        1.1e-34,
        1.1e34,
        -1.1e-34,
        -1.1e34,
        SMALL_NUMBER,
        -SMALL_NUMBER,
    ]

    input_kwargs = copy(sample_input)
    for key, array in input_kwargs.items():
        arr_sampled = np.random.choice(extreme_values, replace=True, size=array.shape)
        input_kwargs[key] = torch.from_numpy(arr_sampled).float()

        if constraints:
            constraint = constraints.get(key, None)
            if constraint:
                if constraint == "positive_not_inf":
                    # Softplus-style transform: maps the extreme values into
                    # (SMALL_NUMBER, LARGE_INTEGER], i.e. strictly positive
                    # and finite.
                    input_kwargs[key] = torch.minimum(
                        SMALL_NUMBER + torch.log(1 + torch.exp(input_kwargs[key])),
                        torch.tensor([LARGE_INTEGER]),
                    )
                elif constraint == "probability":
                    input_kwargs[key] = torch.softmax(input_kwargs[key], dim=-1)

    dist = dist_class(**input_kwargs)
    for _ in range(max_tries):
        sample = dist.sample()

        assert not torch.isnan(sample).any()
        assert torch.all(torch.isfinite(sample))

        logp = dist.logp(sample)
        assert not torch.isnan(logp).any()
        assert torch.all(torch.isfinite(logp))
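The kind of failure `check_stability` guards against can be reproduced in a few lines of numpy (a sketch, independent of RLlib): a naive softmax overflows on large logits and yields `nan`, while the max-shifted form stays finite.

```python
import numpy as np

logits = np.array([1000.0, 0.0])

# Naive softmax overflows: exp(1000) -> inf, and inf / inf -> nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.exp(logits).sum()

# Shifting by the max logit keeps every exponent <= 0, avoiding overflow.
shifted = np.exp(logits - logits.max())
stable = shifted / shifted.sum()

print(naive)   # contains nan
print(stable)  # well-defined probabilities
```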
class TestDistributions(unittest.TestCase):
    """Tests ActionDistribution classes."""

    @classmethod
    def setUpClass(cls) -> None:
        # Set seeds for deterministic tests (make sure we don't fail
        # because of "bad" sampling).
        np.random.seed(42)
        torch.manual_seed(42)
    def test_categorical(self):
        batch_size = 10000
        num_categories = 4
        sample_shape = 2

        # Create a categorical distribution with n categories.
        logits = np.random.randn(batch_size, num_categories)
        probs = torch.from_numpy(softmax(logits)).float()
        logits = torch.from_numpy(logits).float()

        # Check stability against skewed inputs.
        check_stability(TorchCategorical, sample_input={"logits": logits})
        check_stability(
            TorchCategorical,
            sample_input={"probs": logits},
            constraints={"probs": "probability"},
        )

        dist_with_logits = TorchCategorical(logits=logits)
        dist_with_probs = TorchCategorical(probs=probs)

        samples = dist_with_logits.sample(sample_shape=(sample_shape,))

        # Check shape of samples.
        self.assertEqual(
            samples.shape,
            (
                sample_shape,
                batch_size,
            ),
        )
        self.assertEqual(samples.dtype, torch.int64)
        # Check that none of the samples are nan.
        self.assertFalse(torch.isnan(samples).any())
        # Check that all samples are in the range of the number of categories.
        self.assertTrue((samples >= 0).all())
        self.assertTrue((samples < num_categories).all())

        # Resample to remove the leading sample_shape dim.
        samples = dist_with_logits.sample()
        # Check that the two distributions yield the same log-likelihoods.
        check(dist_with_logits.logp(samples), dist_with_probs.logp(samples))

        # Check logp values.
        expected = probs.log().gather(dim=-1, index=samples.view(-1, 1)).view(-1)
        check(dist_with_logits.logp(samples), expected)

        # Check entropy.
        expected = -(probs * probs.log()).sum(dim=-1)
        check(dist_with_logits.entropy(), expected)

        # Check KL.
        probs2 = softmax(np.random.randn(batch_size, num_categories))
        probs2 = torch.from_numpy(probs2).float()
        dist2 = TorchCategorical(probs=probs2)
        expected = (probs * (probs / probs2).log()).sum(dim=-1)
        check(dist_with_probs.kl(dist2), expected)

        # Check rsample (not implemented for this discrete distribution).
        self.assertRaises(NotImplementedError, dist_with_logits.rsample)

        # Test temperature: with a near-zero temperature, sampling should
        # collapse onto the argmax of the logits.
        dist_with_logits = TorchCategorical(logits=logits, temperature=1e-20)
        samples = dist_with_logits.sample()
        expected = logits.argmax(dim=-1)
        check(samples, expected)
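The low-temperature behavior asserted at the end of `test_categorical` can be seen directly in a numpy sketch (`temperature_softmax` is a hypothetical helper, not part of the PR): dividing logits by a temperature before the softmax makes the distribution approach a one-hot at the argmax as the temperature goes to zero, and approach uniform as it grows.

```python
import numpy as np

def temperature_softmax(logits: np.ndarray, temperature: float) -> np.ndarray:
    # Scale logits by 1/temperature, then apply a max-shifted (stable) softmax.
    z = logits / temperature
    z = np.exp(z - z.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 0.5])
print(temperature_softmax(logits, 1.0))    # mildly peaked
print(temperature_softmax(logits, 1e-3))   # ~one-hot at the argmax (index 0)
print(temperature_softmax(logits, 1e6))    # ~uniform
```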
    def test_diag_gaussian(self):
        batch_size = 128
        ndim = 4
        sample_shape = 100000

        loc = np.random.randn(batch_size, ndim)
        scale = np.exp(np.random.randn(batch_size, ndim))

        loc_tens = torch.from_numpy(loc).float()
        scale_tens = torch.from_numpy(scale).float()

        dist = TorchDiagGaussian(loc=loc_tens, scale=scale_tens)
        sample = dist.sample(sample_shape=(sample_shape,))

        # Check shape of samples.
        self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
        self.assertEqual(sample.dtype, torch.float32)
        # Check that none of the samples are nan.
        self.assertFalse(torch.isnan(sample).any())

        # Check that mean and std are approximately correct.
        check(sample.mean(0), loc, decimals=1)
        check(sample.std(0), scale, decimals=1)

        # Check logp values.
        expected = (
            -0.5 * ((sample - loc_tens) / scale_tens).pow(2).sum(-1)
            - 0.5 * ndim * torch.log(2 * torch.tensor([torch.pi]))
            - scale_tens.log().sum(-1)
        )
        check(dist.logp(sample), expected)

        # Check entropy.
        expected = 0.5 * ndim * (
            1 + torch.log(2 * torch.tensor([torch.pi]))
        ) + scale_tens.log().sum(-1)
        check(dist.entropy(), expected)

        # Check KL.
        loc2 = torch.from_numpy(np.random.randn(batch_size, ndim)).float()
        scale2 = torch.from_numpy(np.exp(np.random.randn(batch_size, ndim)))
        dist2 = TorchDiagGaussian(loc=loc2, scale=scale2)
        expected = (
            scale2.log()
            - scale_tens.log()
            + (scale_tens.pow(2) + (loc_tens - loc2).pow(2)) / (2 * scale2.pow(2))
            - 0.5
        ).sum(-1)
        check(dist.kl(dist2), expected, decimals=4)

        # Check rsample: only the re-parameterized sample should be
        # differentiable w.r.t. the distribution parameters.
        loc_tens.requires_grad = True
        scale_tens.requires_grad = True
        dist = TorchDiagGaussian(loc=2 * loc_tens, scale=2 * scale_tens)
        sample1 = dist.rsample()
        sample2 = dist.sample()

        self.assertRaises(
            RuntimeError, lambda: sample2.mean().backward(retain_graph=True)
        )
        sample1.mean().backward(retain_graph=True)

        # Check stability against skewed inputs.
        check_stability(
            TorchDiagGaussian,
            sample_input={"loc": loc_tens, "scale": scale_tens},
            constraints={"scale": "positive_not_inf"},
        )
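The closed-form diagonal-Gaussian KL used as the expected value above can be cross-checked against a Monte Carlo estimate; a one-dimensional numpy sketch (independent of the test code):

```python
import numpy as np

rng = np.random.default_rng(0)
m1, s1 = 0.0, 1.0  # p = N(m1, s1^2)
m2, s2 = 1.0, 2.0  # q = N(m2, s2^2)

# Closed form: KL(p || q) = log(s2/s1) + (s1^2 + (m1-m2)^2) / (2 s2^2) - 1/2.
kl_closed = np.log(s2 / s1) + (s1**2 + (m1 - m2) ** 2) / (2 * s2**2) - 0.5

# Monte Carlo: KL(p || q) = E_{x~p}[log p(x) - log q(x)].
x = rng.normal(m1, s1, 1_000_000)
logp = -0.5 * ((x - m1) / s1) ** 2 - np.log(s1) - 0.5 * np.log(2 * np.pi)
logq = -0.5 * ((x - m2) / s2) ** 2 - np.log(s2) - 0.5 * np.log(2 * np.pi)
kl_mc = (logp - logq).mean()

print(kl_closed, kl_mc)  # the two estimates agree to ~2 decimal places
```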
    def test_deterministic(self):
        batch_size = 128
        ndim = 4
        sample_shape = 100000

        loc = np.random.randn(batch_size, ndim)

        loc_tens = torch.from_numpy(loc).float()

        dist = TorchDeterministic(loc=loc_tens)
        sample = dist.sample(sample_shape=(sample_shape,))
        sample2 = dist.sample(sample_shape=(sample_shape,))
        # A deterministic distribution always returns the same sample.
        check(sample, sample2)

        # Check shape of samples.
        self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
        self.assertEqual(sample.dtype, torch.float32)
        # Check that none of the samples are nan.
        self.assertFalse(torch.isnan(sample).any())


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
Review comment:
> We should really take this opportunity and get rid of the restriction to actions. Can we just call this `Distribution`? That would then cover models that predict next states, rewards, actions, etc., as well as serve inside variational autoencoders.

Author reply:
> Yeah, I agree. Fixed.