[RLlib] Created action_dist_v2 for RLModule examples, RLModule PR 2/N #29600

Merged 10 commits on Oct 24, 2022. Changes shown from 8 commits.
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -1772,6 +1772,13 @@ py_test(
srcs = ["models/tests/test_distributions.py"]
)

py_test(
name = "test_distributions_v2",
tags = ["team:rllib", "models"],
size = "medium",
srcs = ["models/tests/test_distributions_v2.py"]
)

py_test(
name = "test_lstms",
tags = ["team:rllib", "models"],
124 changes: 124 additions & 0 deletions rllib/models/distributions.py
@@ -0,0 +1,124 @@
"""This is the next version of action distribution base class."""
from typing import Tuple
import gym
import abc

from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.typing import TensorType, ModelConfigDict


@ExperimentalAPI
class Distribution(abc.ABC):
"""The base class for distribution over a random variable.

Examples:
>>> model = ... # a model that outputs a vector of logits
>>> action_logits = model.forward(obs)
>>> action_dist = Distribution(action_logits)
>>> action = action_dist.sample()
>>> logp = action_dist.logp(action)
>>> action_dist2 = ...  # another Distribution to compare against
>>> kl = action_dist.kl(action_dist2)
>>> entropy = action_dist.entropy()

"""

@abc.abstractmethod
def sample(
self,
*,
sample_shape: Optional[Tuple[int, ...]] = None,
return_logp: bool = False,
**kwargs
) -> Union[TensorType, Tuple[TensorType, TensorType]]:
"""Draw a sample from the distribution.

Args:
sample_shape: The shape of the sample to draw.
return_logp: Whether to return the logp of the sampled values.
**kwargs: Forward compatibility placeholder.

Returns:
The sampled values. If return_logp is True, returns a tuple of the
sampled values and its logp.
"""

@abc.abstractmethod
def rsample(
self,
*,
sample_shape: Optional[Tuple[int, ...]] = None,
return_logp: bool = False,
**kwargs
) -> Union[TensorType, Tuple[TensorType, TensorType]]:
"""Draw a re-parameterized sample from the action distribution.

If this method is implemented, we can take gradients of samples w.r.t. the
distribution parameters.

Args:
sample_shape: The shape of the sample to draw.
return_logp: Whether to return the logp of the sampled values.
**kwargs: Forward compatibility placeholder.

Returns:
The sampled values. If return_logp is True, returns a tuple of the
sampled values and its logp.
"""

@abc.abstractmethod
def logp(self, value: TensorType, **kwargs) -> TensorType:
"""The log-likelihood of the distribution computed at `value`

Args:
value: The value to compute the log-likelihood at.
**kwargs: Forward compatibility placeholder.

Returns:
The log-likelihood of the value.
"""

@abc.abstractmethod
def kl(self, other: "Distribution", **kwargs) -> TensorType:
"""The KL-divergence between two distributions.

Args:
other: The other distribution.
**kwargs: Forward compatibility placeholder.

Returns:
The KL-divergence between the two distributions.
"""

@abc.abstractmethod
def entropy(self, **kwargs) -> TensorType:
"""The entropy of the distribution.

Args:
**kwargs: Forward compatibility placeholder.

Returns:
The entropy of the distribution.
"""

@staticmethod
@abc.abstractmethod
def required_model_output_shape(
space: gym.Space, model_config: ModelConfigDict
) -> Tuple[int, ...]:
"""Returns the required shape of an input parameter tensor for a
particular space and an optional dict of distribution-specific
options.

This method is kept here as a reminder to the next developer that it was
part of the old distribution classes; whether it stays depends on how the
catalog gets written.

Args:
space: The space this distribution will be used for,
whose shape attributes will be used to determine the required shape of
the input parameter tensor.
model_config: Model's config dict (as defined in catalog.py)

Returns:
Size of the required input vector (minus the leading batch dimension).
"""
236 changes: 236 additions & 0 deletions rllib/models/tests/test_distributions_v2.py
@@ -0,0 +1,236 @@
from copy import copy
import numpy as np
import unittest
import math

from ray.rllib.models.torch.torch_distributions import (
TorchCategorical,
TorchDiagGaussian,
TorchDeterministic,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import (
softmax,
SMALL_NUMBER,
LARGE_INTEGER,
)
from ray.rllib.utils.test_utils import check

torch, _ = try_import_torch()


def check_stability(dist_class, *, sample_input=None, constraints=None):
max_tries = 100
extreme_values = [
0.0,
float(LARGE_INTEGER),
-float(LARGE_INTEGER),
1.1e-34,
1.1e34,
-1.1e-34,
-1.1e34,
SMALL_NUMBER,
-SMALL_NUMBER,
]

input_kwargs = copy(sample_input)
for key, array in input_kwargs.items():
arr_sampled = np.random.choice(extreme_values, replace=True, size=array.shape)
input_kwargs[key] = torch.from_numpy(arr_sampled).float()

if constraints:
constraint = constraints.get(key, None)
if constraint:
if constraint == "positive_not_inf":
input_kwargs[key] = torch.minimum(
SMALL_NUMBER + torch.log(1 + torch.exp(input_kwargs[key])),
torch.tensor([LARGE_INTEGER]),
)
elif constraint == "probability":
input_kwargs[key] = torch.softmax(input_kwargs[key], dim=-1)

dist = dist_class(**input_kwargs)
for _ in range(max_tries):
sample = dist.sample()

assert not torch.isnan(sample).any()
assert torch.all(torch.isfinite(sample))

logp = dist.logp(sample)
assert not torch.isnan(logp).any()
assert torch.all(torch.isfinite(logp))


class TestDistributions(unittest.TestCase):
"""Tests ActionDistribution classes."""

@classmethod
def setUpClass(cls) -> None:
# Set seeds for deterministic tests (make sure we don't fail
# because of "bad" sampling).
np.random.seed(42)
torch.manual_seed(42)

def test_categorical(self):
batch_size = 10000
num_categories = 4
sample_shape = 2

# Create categorical distribution with n categories.
logits = np.random.randn(batch_size, num_categories)
probs = torch.from_numpy(softmax(logits)).float()
logits = torch.from_numpy(logits).float()

# check stability against skewed inputs
check_stability(TorchCategorical, sample_input={"logits": logits})
check_stability(
TorchCategorical,
sample_input={"probs": logits},
constraints={"probs": "probability"},
)

dist_with_logits = TorchCategorical(logits=logits)
dist_with_probs = TorchCategorical(probs=probs)

samples = dist_with_logits.sample(sample_shape=(sample_shape,))

# check shape of samples
self.assertEqual(
samples.shape,
(
sample_shape,
batch_size,
),
)
self.assertEqual(samples.dtype, torch.int64)
# check that none of the samples are nan
self.assertFalse(torch.isnan(samples).any())
# check that all samples are in the range of the number of categories
self.assertTrue((samples >= 0).all())
self.assertTrue((samples < num_categories).all())

# sample again without a sample_shape so there is no leading sample dim
samples = dist_with_logits.sample()
# check that the two distributions are the same
check(dist_with_logits.logp(samples), dist_with_probs.logp(samples))

# check logp values
expected = probs.log().gather(dim=-1, index=samples.view(-1, 1)).view(-1)
check(dist_with_logits.logp(samples), expected)

# check entropy
expected = -(probs * probs.log()).sum(dim=-1)
check(dist_with_logits.entropy(), expected)

# check kl
probs2 = softmax(np.random.randn(batch_size, num_categories))
probs2 = torch.from_numpy(probs2).float()
dist2 = TorchCategorical(probs=probs2)
expected = (probs * (probs / probs2).log()).sum(dim=-1)
check(dist_with_probs.kl(dist2), expected)

# check rsample
self.assertRaises(NotImplementedError, dist_with_logits.rsample)

# test temperature
dist_with_logits = TorchCategorical(logits=logits, temperature=1e-20)
samples = dist_with_logits.sample()
# expected is argmax of logits (near-zero temperature makes sampling greedy)
expected = logits.argmax(dim=-1)
check(samples, expected)

def test_diag_gaussian(self):
batch_size = 128
ndim = 4
sample_shape = 100000

loc = np.random.randn(batch_size, ndim)
scale = np.exp(np.random.randn(batch_size, ndim))

loc_tens = torch.from_numpy(loc).float()
scale_tens = torch.from_numpy(scale).float()

dist = TorchDiagGaussian(loc=loc_tens, scale=scale_tens)
sample = dist.sample(sample_shape=(sample_shape,))

# check shape of samples
self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
self.assertEqual(sample.dtype, torch.float32)
# check that none of the samples are nan
self.assertFalse(torch.isnan(sample).any())

# check that mean and std are approximately correct
check(sample.mean(0), loc, decimals=1)
check(sample.std(0), scale, decimals=1)

# check logp values
expected = (
-0.5 * ((sample - loc_tens) / scale_tens).pow(2).sum(-1)
+ -0.5 * ndim * math.log(2 * math.pi)
- scale_tens.log().sum(-1)
)
check(dist.logp(sample), expected)

# check entropy
expected = 0.5 * ndim * (
1 + torch.log(2 * torch.tensor([torch.pi]))
) + scale_tens.log().sum(-1)
check(dist.entropy(), expected)

# check kl
loc2 = torch.from_numpy(np.random.randn(batch_size, ndim)).float()
scale2 = torch.from_numpy(np.exp(np.random.randn(batch_size, ndim)))
dist2 = TorchDiagGaussian(loc=loc2, scale=scale2)
expected = (
scale2.log()
- scale_tens.log()
+ (scale_tens.pow(2) + (loc_tens - loc2).pow(2)) / (2 * scale2.pow(2))
- 0.5
).sum(-1)
check(dist.kl(dist2), expected, decimals=4)

# check rsample
loc_tens.requires_grad = True
scale_tens.requires_grad = True
dist = TorchDiagGaussian(loc=2 * loc_tens, scale=2 * scale_tens)
sample1 = dist.rsample()
sample2 = dist.sample()

self.assertRaises(
RuntimeError, lambda: sample2.mean().backward(retain_graph=True)
)
sample1.mean().backward(retain_graph=True)

# check stability against skewed inputs
check_stability(
TorchDiagGaussian,
sample_input={"loc": loc_tens, "scale": scale_tens},
constraints={"scale": "positive_not_inf"},
)

def test_deterministic(self):
batch_size = 128
ndim = 4
sample_shape = 100000

loc = np.random.randn(batch_size, ndim)

loc_tens = torch.from_numpy(loc).float()

dist = TorchDeterministic(loc=loc_tens)
sample = dist.sample(sample_shape=(sample_shape,))
sample2 = dist.sample(sample_shape=(sample_shape,))
check(sample, sample2)

# check shape of samples
self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
self.assertEqual(sample.dtype, torch.float32)
# check that none of the samples are nan
self.assertFalse(torch.isnan(sample).any())


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))
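As a side check (not part of the PR), the closed-form entropy and KL expressions asserted in test_diag_gaussian above can be cross-validated against torch.distributions. A minimal sketch, assuming a PyTorch version that provides torch.pi and a registered KL for Independent Normal distributions:

import torch

loc, scale = torch.randn(8, 4), torch.rand(8, 4) + 0.1
loc2, scale2 = torch.randn(8, 4), torch.rand(8, 4) + 0.1

# Independent(..., 1) sums per-dimension quantities over the last axis,
# matching the .sum(-1) reductions used in the test.
d1 = torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
d2 = torch.distributions.Independent(torch.distributions.Normal(loc2, scale2), 1)

ndim = loc.shape[-1]
entropy_expected = 0.5 * ndim * (1 + torch.log(2 * torch.tensor(torch.pi))) + scale.log().sum(-1)
kl_expected = (
    scale2.log() - scale.log()
    + (scale.pow(2) + (loc - loc2).pow(2)) / (2 * scale2.pow(2))
    - 0.5
).sum(-1)

assert torch.allclose(d1.entropy(), entropy_expected, atol=1e-5)
assert torch.allclose(torch.distributions.kl_divergence(d1, d2), kl_expected, atol=1e-5)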