Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Created action_dist_v2 for RLModule examples, RLModule PR 2/N #29600

Merged
merged 10 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1772,6 +1772,13 @@ py_test(
srcs = ["models/tests/test_distributions.py"]
)

# Unit tests for the v2 distribution classes (models/action_dist_v2.py).
py_test(
    name = "test_distributions_v2",
    size = "medium",
    srcs = ["models/tests/test_distributions_v2.py"],
    tags = ["team:rllib", "models"],
)

py_test(
name = "test_lstms",
tags = ["team:rllib", "models"],
Expand Down
119 changes: 119 additions & 0 deletions rllib/models/action_dist_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import abc
from typing import Optional, Tuple

import gym

from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.typing import TensorType, Union, ModelConfigDict


@ExperimentalAPI
class ActionDistributionV2(abc.ABC):
    """The policy action distribution of an agent.

    An instance is typically constructed from a model's output (e.g. a
    tensor of logits) and exposes sampling plus the likelihood, entropy,
    and divergence quantities needed by RL losses.

    Examples:
        >>> model = ...  # a model that outputs a vector of logits
        >>> action_logits = model.forward(obs)
        >>> action_dist = ActionDistributionV2(action_logits)
        >>> action = action_dist.sample()
        >>> logp = action_dist.logp(action)
        >>> kl = action_dist.kl(action_dist2)
        >>> entropy = action_dist.entropy()
    """

    @abc.abstractmethod
    def sample(
        self,
        *,
        sample_shape: Optional[Tuple[int, ...]] = None,
        return_logp: bool = False,
        **kwargs
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a sample from the action distribution.

        Args:
            sample_shape: The shape of the sample to draw. If None, draw a
                single sample; the exact semantics (e.g. interaction with
                batch dims) are documented by each concrete distribution.
            return_logp: Whether to return the logp of the sampled action.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled action. If return_logp is True, returns a tuple of
            the sampled action and its logp.
        """

    @abc.abstractmethod
    def rsample(
        self,
        *,
        sample_shape: Optional[Tuple[int, ...]] = None,
        return_logp: bool = False,
        **kwargs
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a re-parameterized sample from the action distribution.

        Unlike `sample()`, the returned tensor is produced via the
        re-parameterization trick, so gradients can be back-propagated
        through it into the distribution's input parameters.

        Args:
            sample_shape: The shape of the sample to draw.
            return_logp: Whether to return the logp of the sampled action.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled action. If return_logp is True, returns a tuple of
            the sampled action and its logp.
        """

    @abc.abstractmethod
    def logp(self, action: TensorType, **kwargs) -> TensorType:
        """The log-likelihood of the action distribution.

        Args:
            action: The action to compute the log-likelihood for.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The log-likelihood of the action.
        """

    @abc.abstractmethod
    def kl(self, other: "ActionDistributionV2", **kwargs) -> TensorType:
        """The KL-divergence between two action distributions.

        Args:
            other: The other action distribution.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The KL-divergence between the two action distributions.
        """

    @abc.abstractmethod
    def entropy(self, **kwargs) -> TensorType:
        """The entropy of the action distribution.

        Args:
            **kwargs: Forward compatibility placeholder.

        Returns:
            The entropy of the action distribution.
        """

    @staticmethod
    @abc.abstractmethod
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Tuple[int, ...]:
        """Returns the required shape of an input parameter tensor.

        The shape is derived from a particular action space and an optional
        dict of distribution-specific options, and excludes any leading
        batch dimension.

        NOTE(review): carried over from the old distribution API; whether
        it survives depends on how the catalog gets written.

        Args:
            action_space: The action space this distribution will be used
                for, whose shape attributes will be used to determine the
                required shape of the input parameter tensor.
            model_config: Model's config dict (as defined in catalog.py).

        Returns:
            size of the required input vector (minus leading batch
            dimension).
        """
235 changes: 235 additions & 0 deletions rllib/models/tests/test_distributions_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from copy import copy
import numpy as np
import unittest

from ray.rllib.models.torch.torch_action_dist_v2 import (
TorchCategorical,
TorchDiagGaussian,
TorchDeterministic,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import (
softmax,
SMALL_NUMBER,
LARGE_INTEGER,
)
from ray.rllib.utils.test_utils import check

torch, _ = try_import_torch()


def check_stability(dist_class, *, sample_input=None, constraints=None):
    """Probe a distribution class for numerical stability on extreme inputs.

    Replaces each entry of `sample_input` with values drawn from a fixed
    set of extreme magnitudes (optionally squashed per `constraints`),
    builds the distribution, then repeatedly samples and computes logp,
    asserting that neither ever produces NaN or inf.

    Args:
        dist_class: The distribution class under test.
        sample_input: Dict of constructor kwargs; each value's shape is
            used as the template for the extreme-valued replacement.
        constraints: Optional per-key constraint, either
            "positive_not_inf" (softplus-like squash, capped) or
            "probability" (softmax over the last dim).
    """
    num_rounds = 100
    extremes = [
        0.0,
        float(LARGE_INTEGER),
        -float(LARGE_INTEGER),
        1.1e-34,
        1.1e34,
        -1.1e-34,
        -1.1e34,
        SMALL_NUMBER,
        -SMALL_NUMBER,
    ]

    params = copy(sample_input)
    for name, template in params.items():
        drawn = np.random.choice(extremes, replace=True, size=template.shape)
        tensor = torch.from_numpy(drawn).float()

        rule = (constraints or {}).get(name, None)
        if rule == "positive_not_inf":
            # Softplus-style squash to strictly positive, capped to avoid inf.
            tensor = torch.minimum(
                SMALL_NUMBER + torch.log(1 + torch.exp(tensor)),
                torch.tensor([LARGE_INTEGER]),
            )
        elif rule == "probability":
            tensor = torch.softmax(tensor, dim=-1)
        params[name] = tensor

    dist = dist_class(**params)
    for _ in range(num_rounds):
        draw = dist.sample()

        assert not torch.isnan(draw).any()
        assert torch.all(torch.isfinite(draw))

        lp = dist.logp(draw)
        assert not torch.isnan(lp).any()
        assert torch.all(torch.isfinite(lp))


class TestDistributions(unittest.TestCase):
    """Tests the v2 distribution classes (sample/logp/entropy/kl/rsample)."""

    @classmethod
    def setUpClass(cls) -> None:
        # Set seeds for deterministic tests (make sure we don't fail
        # because of "bad" sampling).
        np.random.seed(42)
        torch.manual_seed(42)

    def test_categorical(self):
        batch_size = 10000
        num_categories = 4
        sample_shape = 2

        # Create categorical distribution with n categories.
        logits = np.random.randn(batch_size, num_categories)
        probs = torch.from_numpy(softmax(logits)).float()
        logits = torch.from_numpy(logits).float()

        # Check stability against skewed inputs. Raw logits are valid for
        # the "probs" input as well here, since the "probability"
        # constraint softmaxes them first.
        check_stability(TorchCategorical, sample_input={"logits": logits})
        check_stability(
            TorchCategorical,
            sample_input={"probs": logits},
            constraints={"probs": "probability"},
        )

        dist_with_logits = TorchCategorical(logits=logits)
        dist_with_probs = TorchCategorical(probs=probs)

        samples = dist_with_logits.sample(sample_shape=(sample_shape,))

        # Check shape of samples.
        self.assertEqual(
            samples.shape,
            (
                sample_shape,
                batch_size,
            ),
        )
        self.assertEqual(samples.dtype, torch.int64)
        # Check that none of the samples are nan.
        self.assertFalse(torch.isnan(samples).any())
        # Check that all samples are in the range of the number of categories.
        self.assertTrue((samples >= 0).all())
        self.assertTrue((samples < num_categories).all())

        # Resample to remove the first batch dim.
        samples = dist_with_logits.sample()
        # Check that the two distributions are the same.
        check(dist_with_logits.logp(samples), dist_with_probs.logp(samples))

        # Check logp values against a manual gather of log-probs.
        expected = probs.log().gather(dim=-1, index=samples.view(-1, 1)).view(-1)
        check(dist_with_logits.logp(samples), expected)

        # Check entropy: H = -sum(p * log(p)).
        expected = -(probs * probs.log()).sum(dim=-1)
        check(dist_with_logits.entropy(), expected)

        # Check kl: KL(p || q) = sum(p * log(p / q)).
        probs2 = softmax(np.random.randn(batch_size, num_categories))
        probs2 = torch.from_numpy(probs2).float()
        dist2 = TorchCategorical(probs=probs2)
        expected = (probs * (probs / probs2).log()).sum(dim=-1)
        check(dist_with_probs.kl(dist2), expected)

        # Check rsample: categorical sampling is not re-parameterizable.
        self.assertRaises(NotImplementedError, dist_with_logits.rsample)

        # Test temperature: at near-zero temperature, sampling collapses
        # to the mode, so expected is argmax of logits.
        dist_with_logits = TorchCategorical(logits=logits, temperature=1e-20)
        samples = dist_with_logits.sample()
        expected = logits.argmax(dim=-1)
        check(samples, expected)

    def test_diag_gaussian(self):
        batch_size = 128
        ndim = 4
        sample_shape = 100000

        loc = np.random.randn(batch_size, ndim)
        scale = np.exp(np.random.randn(batch_size, ndim))

        loc_tens = torch.from_numpy(loc).float()
        scale_tens = torch.from_numpy(scale).float()

        dist = TorchDiagGaussian(loc=loc_tens, scale=scale_tens)
        sample = dist.sample(sample_shape=(sample_shape,))

        # Check shape of samples.
        self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
        self.assertEqual(sample.dtype, torch.float32)
        # Check that none of the samples are nan.
        self.assertFalse(torch.isnan(sample).any())

        # Check that mean and std are approximately correct.
        check(sample.mean(0), loc, decimals=1)
        check(sample.std(0), scale, decimals=1)

        # Check logp values against the closed-form diag-Gaussian density.
        expected = (
            -0.5 * ((sample - loc_tens) / scale_tens).pow(2).sum(-1)
            + -0.5 * ndim * torch.log(2 * torch.tensor([torch.pi]))
            - scale_tens.log().sum(-1)
        )
        check(dist.logp(sample), expected)

        # Check entropy: 0.5 * d * (1 + log(2*pi)) + sum(log(scale)).
        expected = 0.5 * ndim * (
            1 + torch.log(2 * torch.tensor([torch.pi]))
        ) + scale_tens.log().sum(-1)
        check(dist.entropy(), expected)

        # Check kl against the closed-form Gaussian KL.
        # NOTE(review): scale2 is float64 (no .float()); relies on dtype
        # promotion inside kl - confirm this is intended.
        loc2 = torch.from_numpy(np.random.randn(batch_size, ndim)).float()
        scale2 = torch.from_numpy(np.exp(np.random.randn(batch_size, ndim)))
        dist2 = TorchDiagGaussian(loc=loc2, scale=scale2)
        expected = (
            scale2.log()
            - scale_tens.log()
            + (scale_tens.pow(2) + (loc_tens - loc2).pow(2)) / (2 * scale2.pow(2))
            - 0.5
        ).sum(-1)
        check(dist.kl(dist2), expected, decimals=4)

        # Check rsample: gradients must flow through rsample() but not
        # through sample().
        loc_tens.requires_grad = True
        scale_tens.requires_grad = True
        dist = TorchDiagGaussian(loc=2 * loc_tens, scale=2 * scale_tens)
        sample1 = dist.rsample()
        sample2 = dist.sample()

        self.assertRaises(
            RuntimeError, lambda: sample2.mean().backward(retain_graph=True)
        )
        sample1.mean().backward(retain_graph=True)

        # Check stability against skewed inputs.
        check_stability(
            TorchDiagGaussian,
            sample_input={"loc": loc_tens, "scale": scale_tens},
            constraints={"scale": "positive_not_inf"},
        )

    def test_deterministic(self):
        batch_size = 128
        ndim = 4
        sample_shape = 100000

        loc = np.random.randn(batch_size, ndim)

        loc_tens = torch.from_numpy(loc).float()

        # A deterministic distribution must return the same sample every time.
        dist = TorchDeterministic(loc=loc_tens)
        sample = dist.sample(sample_shape=(sample_shape,))
        sample2 = dist.sample(sample_shape=(sample_shape,))
        check(sample, sample2)

        # Check shape of samples.
        self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
        self.assertEqual(sample.dtype, torch.float32)
        # Check that none of the samples are nan.
        self.assertFalse(torch.isnan(sample).any())


if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests through pytest when executed directly.
    sys.exit(pytest.main(["-v", __file__]))
Loading