diff --git a/rllib/BUILD b/rllib/BUILD index 8b89f468ec54..42fcb8442f4f 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1772,6 +1772,13 @@ py_test( srcs = ["models/tests/test_distributions.py"] ) +py_test( + name = "test_distributions_v2", + tags = ["team:rllib", "models"], + size = "medium", + srcs = ["models/tests/test_distributions_v2.py"] +) + py_test( name = "test_lstms", tags = ["team:rllib", "models"], diff --git a/rllib/models/distributions.py b/rllib/models/distributions.py new file mode 100644 index 000000000000..4e1fb29c673c --- /dev/null +++ b/rllib/models/distributions.py @@ -0,0 +1,124 @@ +"""This is the next version of action distribution base class.""" +from typing import Tuple +import gym +import abc + +from ray.rllib.utils.annotations import ExperimentalAPI +from ray.rllib.utils.typing import TensorType, Union, ModelConfigDict + + +@ExperimentalAPI +class Distribution(abc.ABC): + """The base class for distribution over a random variable. + + Examples: + >>> model = ... # a model that outputs a vector of logits + >>> action_logits = model.forward(obs) + >>> action_dist = Distribution(action_logits) + >>> action = action_dist.sample() + >>> logp = action_dist.logp(action) + >>> kl = action_dist.kl(action_dist2) + >>> entropy = action_dist.entropy() + + """ + + @abc.abstractmethod + def sample( + self, + *, + sample_shape: Tuple[int, ...] = None, + return_logp: bool = False, + **kwargs + ) -> Union[TensorType, Tuple[TensorType, TensorType]]: + """Draw a sample from the distribution. + + Args: + sample_shape: The shape of the sample to draw. + return_logp: Whether to return the logp of the sampled values. + **kwargs: Forward compatibility placeholder. + + Returns: + The sampled values. If return_logp is True, returns a tuple of the + sampled values and its logp. + """ + + @abc.abstractmethod + def rsample( + self, + *, + sample_shape: Tuple[int, ...] = None, + return_logp: bool = False, + **kwargs + ) -> Union[TensorType, Tuple[TensorType, TensorType]]: + """Draw a re-parameterized sample from the action distribution. + + If this method is implemented, we can take gradients of samples w.r.t. the + distribution parameters. + + Args: + sample_shape: The shape of the sample to draw. + return_logp: Whether to return the logp of the sampled values. + **kwargs: Forward compatibility placeholder. + + Returns: + The sampled values. If return_logp is True, returns a tuple of the + sampled values and its logp. + """ + + @abc.abstractmethod + def logp(self, value: TensorType, **kwargs) -> TensorType: + """The log-likelihood of the distribution computed at `value` + + Args: + value: The value to compute the log-likelihood at. + **kwargs: Forward compatibility placeholder. + + Returns: + The log-likelihood of the value. + """ + + @abc.abstractmethod + def kl(self, other: "Distribution", **kwargs) -> TensorType: + """The KL-divergence between two distributions. + + Args: + other: The other distribution. + **kwargs: Forward compatibility placeholder. + + Returns: + The KL-divergence between the two distributions. + """ + + @abc.abstractmethod + def entropy(self, **kwargs) -> TensorType: + """The entropy of the distribution. + + Args: + **kwargs: Forward compatibility placeholder. + + Returns: + The entropy of the distribution. 
+ """ + + @staticmethod + @abc.abstractmethod + def required_model_output_shape( + space: gym.Space, model_config: ModelConfigDict + ) -> Tuple[int, ...]: + """Returns the required shape of an input parameter tensor for a + particular space and an optional dict of distribution-specific + options. + + Let's have this method here just as a reminder to the next developer that this + was part of the old distribution classes that we may or may not keep depending + on how the catalog gets written. + + Args: + space: The space this distribution will be used for, + whose shape attributes will be used to determine the required shape of + the input parameter tensor. + model_config: Model's config dict (as defined in catalog.py) + + Returns: + size of the required input vector (minus leading batch dimension). + """ diff --git a/rllib/models/tests/test_distributions_v2.py b/rllib/models/tests/test_distributions_v2.py new file mode 100644 index 000000000000..98f550b7003c --- /dev/null +++ b/rllib/models/tests/test_distributions_v2.py @@ -0,0 +1,234 @@ +from copy import copy +import numpy as np +import unittest +import math + +from ray.rllib.models.torch.torch_distributions import ( + TorchCategorical, + TorchDiagGaussian, + TorchDeterministic, +) +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import ( + softmax, + SMALL_NUMBER, + LARGE_INTEGER, +) +from ray.rllib.utils.test_utils import check + +torch, _ = try_import_torch() + + +def check_stability(dist_class, *, sample_input=None, constraints=None): + max_tries = 100 + extreme_values = [ + 0.0, + float(LARGE_INTEGER), + -float(LARGE_INTEGER), + 1.1e-34, + 1.1e34, + -1.1e-34, + -1.1e34, + SMALL_NUMBER, + -SMALL_NUMBER, + ] + + input_kwargs = copy(sample_input) + for key, array in input_kwargs.items(): + arr_sampled = np.random.choice(extreme_values, replace=True, size=array.shape) + input_kwargs[key] = torch.from_numpy(arr_sampled).float() + + if constraints: + constraint = constraints.get(key, None) + if constraint: + if constraint == "positive_not_inf": + input_kwargs[key] = torch.minimum( + SMALL_NUMBER + torch.log(1 + torch.exp(input_kwargs[key])), + torch.tensor([LARGE_INTEGER]), + ) + elif constraint == "probability": + input_kwargs[key] = torch.softmax(input_kwargs[key], dim=-1) + + dist = dist_class(**input_kwargs) + for _ in range(max_tries): + sample = dist.sample() + + assert not torch.isnan(sample).any() + assert torch.all(torch.isfinite(sample)) + + logp = dist.logp(sample) + assert not torch.isnan(logp).any() + assert torch.all(torch.isfinite(logp)) + + +class TestDistributions(unittest.TestCase): + """Tests ActionDistribution classes.""" + + @classmethod + def setUpClass(cls) -> None: + # Set seeds for deterministic tests (make sure we don't fail + # because of "bad" sampling). + np.random.seed(42) + torch.manual_seed(42) + + def test_categorical(self): + batch_size = 10000 + num_categories = 4 + sample_shape = 2 + + # Create categorical distribution with n categories. 
+        logits = np.random.randn(batch_size, num_categories)
+        probs = torch.from_numpy(softmax(logits)).float()
+        logits = torch.from_numpy(logits).float()
+
+        # check stability against skewed inputs
+        check_stability(TorchCategorical, sample_input={"logits": logits})
+        check_stability(
+            TorchCategorical,
+            sample_input={"probs": logits},
+            constraints={"probs": "probability"},
+        )
+
+        dist_with_logits = TorchCategorical(logits=logits)
+        dist_with_probs = TorchCategorical(probs=probs)
+
+        samples = dist_with_logits.sample(sample_shape=(sample_shape,))
+
+        # check shape of samples
+        self.assertEqual(
+            samples.shape,
+            (
+                sample_shape,
+                batch_size,
+            ),
+        )
+        self.assertEqual(samples.dtype, torch.int64)
+        # check that none of the samples are nan
+        self.assertFalse(torch.isnan(samples).any())
+        # check that all samples are in the range of the number of categories
+        self.assertTrue((samples >= 0).all())
+        self.assertTrue((samples < num_categories).all())
+
+        # sample again without a sample_shape to drop the extra leading dim
+        samples = dist_with_logits.sample()
+        # check that the two parameterizations yield the same logp
+        check(dist_with_logits.logp(samples), dist_with_probs.logp(samples))
+
+        # check logp values
+        expected = probs.log().gather(dim=-1, index=samples.view(-1, 1)).view(-1)
+        check(dist_with_logits.logp(samples), expected)
+
+        # check entropy
+        expected = -(probs * probs.log()).sum(dim=-1)
+        check(dist_with_logits.entropy(), expected)
+
+        # check kl
+        probs2 = softmax(np.random.randn(batch_size, num_categories))
+        probs2 = torch.from_numpy(probs2).float()
+        dist2 = TorchCategorical(probs=probs2)
+        expected = (probs * (probs / probs2).log()).sum(dim=-1)
+        check(dist_with_probs.kl(dist2), expected)
+
+        # check rsample
+        self.assertRaises(NotImplementedError, dist_with_logits.rsample)
+
+        # test temperature
+        dist_with_logits = TorchCategorical(logits=logits, temperature=1e-20)
+        samples = dist_with_logits.sample()
+        # expected is argmax of logits
+        expected = logits.argmax(dim=-1)
+        check(samples, expected)
+
+    def test_diag_gaussian(self):
+        batch_size = 128
+        ndim = 4
+        sample_shape = 100000
+
+        loc = np.random.randn(batch_size, ndim)
+        scale = np.exp(np.random.randn(batch_size, ndim))
+
+        loc_tens = torch.from_numpy(loc).float()
+        scale_tens = torch.from_numpy(scale).float()
+
+        dist = TorchDiagGaussian(loc=loc_tens, scale=scale_tens)
+        sample = dist.sample(sample_shape=(sample_shape,))
+
+        # check shape of samples
+        self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
+        self.assertEqual(sample.dtype, torch.float32)
+        # check that none of the samples are nan
+        self.assertFalse(torch.isnan(sample).any())
+
+        # check that mean and std are approximately correct
+        check(sample.mean(0), loc, decimals=1)
+        check(sample.std(0), scale, decimals=1)
+
+        # check logp values
+        expected = (
+            -0.5 * ((sample - loc_tens) / scale_tens).pow(2).sum(-1)
+            + -0.5 * ndim * math.log(2 * math.pi)
+            - scale_tens.log().sum(-1)
+        )
+        check(dist.logp(sample), expected)
+
+        # check entropy
+        expected = 0.5 * ndim * (1 + math.log(2 * math.pi)) + scale_tens.log().sum(-1)
+        check(dist.entropy(), expected)
+
+        # check kl
+        loc2 = torch.from_numpy(np.random.randn(batch_size, ndim)).float()
+        scale2 = torch.from_numpy(np.exp(np.random.randn(batch_size, ndim)))
+        dist2 = TorchDiagGaussian(loc=loc2, scale=scale2)
+        expected = (
+            scale2.log()
+            - scale_tens.log()
+            + (scale_tens.pow(2) + (loc_tens - loc2).pow(2)) / (2 * scale2.pow(2))
+            - 0.5
+        ).sum(-1)
+        check(dist.kl(dist2), expected, decimals=4)
+
+        # check rsample
+        loc_tens.requires_grad = True
+        scale_tens.requires_grad = True
+        dist = TorchDiagGaussian(loc=2 * loc_tens, scale=2 * scale_tens)
+        sample1 = dist.rsample()
+        sample2 = dist.sample()
+
+        self.assertRaises(
+            RuntimeError, lambda: sample2.mean().backward(retain_graph=True)
+        )
+        sample1.mean().backward(retain_graph=True)
+
+        # check stability against skewed inputs
+        check_stability(
+            TorchDiagGaussian,
+            sample_input={"loc": loc_tens, "scale": scale_tens},
+            constraints={"scale": "positive_not_inf"},
+        )
+
+    def test_deterministic(self):
+        batch_size = 128
+        ndim = 4
+        sample_shape = 100000
+
+        loc = np.random.randn(batch_size, ndim)
+
+        loc_tens = torch.from_numpy(loc).float()
+
+        dist = TorchDeterministic(loc=loc_tens)
+        sample = dist.sample(sample_shape=(sample_shape,))
+        sample2 = dist.sample(sample_shape=(sample_shape,))
+        check(sample, sample2)
+
+        # check shape of samples
+        self.assertEqual(sample.shape, (sample_shape, batch_size, ndim))
+        self.assertEqual(sample.dtype, torch.float32)
+        # check that none of the samples are nan
+        self.assertFalse(torch.isnan(sample).any())
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py
new file mode 100644
index 000000000000..c72dc082fed7
--- /dev/null
+++ b/rllib/models/torch/torch_distributions.py
@@ -0,0 +1,252 @@
+"""The main difference between this and the old ActionDistribution is that this one
+has more explicit input args, so that the input format does not have to be guessed
+from the code. This matches the design pattern of torch.distributions, which
+developers may already be familiar with.
+"""
+import gym
+import numpy as np
+from typing import Optional
+import abc
+
+
+from ray.rllib.models.distributions import Distribution
+from ray.rllib.utils.annotations import override, DeveloperAPI
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.typing import TensorType, Union, Tuple, ModelConfigDict
+
+torch, nn = try_import_torch()
+
+
+@DeveloperAPI
+class TorchDistribution(Distribution, abc.ABC):
+    """Wrapper class for torch.distributions."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.dist = self._get_torch_distribution(*args, **kwargs)
+
+    @abc.abstractmethod
+    def _get_torch_distribution(
+        self, *args, **kwargs
+    ) -> torch.distributions.Distribution:
+        """Returns the torch.distributions.Distribution object to use."""
+
+    @override(Distribution)
+    def logp(self, value: TensorType, **kwargs) -> TensorType:
+        return self.dist.log_prob(value, **kwargs)
+
+    @override(Distribution)
+    def entropy(self) -> TensorType:
+        return self.dist.entropy()
+
+    @override(Distribution)
+    def kl(self, other: "Distribution") -> TensorType:
+        return torch.distributions.kl.kl_divergence(self.dist, other.dist)
+
+    @override(Distribution)
+    def sample(
+        self, *, sample_shape=torch.Size(), return_logp: bool = False
+    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
+        sample = self.dist.sample(sample_shape)
+        if return_logp:
+            return sample, self.logp(sample)
+        return sample
+
+    @override(Distribution)
+    def rsample(
+        self, *, sample_shape=torch.Size(), return_logp: bool = False
+    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
+        rsample = self.dist.rsample(sample_shape)
+        if return_logp:
+            return rsample, self.logp(rsample)
+        return rsample
+
+
+@DeveloperAPI
+class TorchCategorical(TorchDistribution):
+    """Wrapper class for PyTorch Categorical distribution.
+
+    Creates a categorical distribution parameterized by either :attr:`probs` or
+    :attr:`logits` (but not both).
+
+    Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is
+    ``probs.size(-1)``.
+
+    If `probs` is 1-dimensional with length-`K`, each element is the relative
+    probability of sampling the class at that index.
+
+    If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
+    relative probability vectors.
+
+    Example::
+        >>> m = TorchCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
+        >>> m.sample(sample_shape=(2,))  # equal probability of 0, 1, 2, 3
+        tensor([3, 1])
+
+    Args:
+        probs: The probabilities of each event.
+        logits: Event log probabilities (unnormalized).
+        temperature: In case of using logits, this parameter can be used to determine
+            the sharpness of the distribution, i.e.
+            ``probs = softmax(logits / temperature)``. The temperature must be strictly
+            positive. A low value (e.g. 1e-10) will result in argmax sampling, while a
+            larger value makes the distribution closer to uniform.
+    """
+
+    def __init__(
+        self,
+        probs: torch.Tensor = None,
+        logits: torch.Tensor = None,
+        temperature: float = 1.0,
+    ) -> None:
+        super().__init__(probs=probs, logits=logits, temperature=temperature)
+
+    @override(TorchDistribution)
+    def _get_torch_distribution(
+        self,
+        probs: torch.Tensor = None,
+        logits: torch.Tensor = None,
+        temperature: float = 1.0,
+    ) -> torch.distributions.Distribution:
+        if logits is not None:
+            assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
+            logits = logits / temperature
+        return torch.distributions.categorical.Categorical(probs, logits)
+
+    @staticmethod
+    @override(Distribution)
+    def required_model_output_shape(
+        space: gym.Space, model_config: ModelConfigDict
+    ) -> Tuple[int, ...]:
+        return (space.n,)
+
+
+@DeveloperAPI
+class TorchDiagGaussian(TorchDistribution):
+    """Wrapper class for PyTorch Normal distribution.
+
+    Creates a normal distribution parameterized by :attr:`loc` and :attr:`scale`. In
+    the multi-dimensional case, the covariance matrix is assumed to be diagonal.
+
+    Example::
+
+        >>> m = TorchDiagGaussian(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0, 1.0]))
+        >>> m.sample(sample_shape=(2,))  # 2d normal dist with loc=0 and scale=1
+        tensor([[ 0.1046, -0.6120], [ 0.234, 0.556]])
+
+        >>> # scale is None
+        >>> m = TorchDiagGaussian(loc=torch.tensor([0.0, 1.0]))
+        >>> m.sample(sample_shape=(2,))  # normally distributed with loc=0, scale=exp(1)
+        tensor([[ 0.1046], [ 0.6120]])
+
+
+    Args:
+        loc: mean of the distribution (often referred to as mu). If scale is None, the
+            second half of `loc` will be used as the log of scale.
+        scale: standard deviation of the distribution (often referred to as sigma).
+            Has to be positive.
+ """ + + @override(Distribution) + def __init__( + self, + loc: Union[float, torch.Tensor], + scale: Optional[Union[float, torch.Tensor]] = None, + ): + super().__init__(loc=loc, scale=scale) + + def _get_torch_distribution( + self, loc, scale=None + ) -> torch.distributions.Distribution: + if scale is None: + loc, log_std = torch.chunk(self.inputs, 2, dim=1) + scale = torch.exp(log_std) + return torch.distributions.normal.Normal(loc, scale) + + @override(TorchDistribution) + def logp(self, value: TensorType) -> TensorType: + return super().logp(value).sum(-1) + + @override(TorchDistribution) + def entropy(self) -> TensorType: + return super().entropy().sum(-1) + + @override(TorchDistribution) + def kl(self, other: "TorchDistribution") -> TensorType: + return super().kl(other).sum(-1) + + @staticmethod + @override(Distribution) + def required_model_output_shape( + space: gym.Space, model_config: ModelConfigDict + ) -> Tuple[int, ...]: + return tuple(np.prod(space.shape, dtype=np.int32) * 2) + + +@DeveloperAPI +class TorchDeterministic(Distribution): + """The distribution that returns the input values directly. + + This is similar to DiagGaussian with standard deviation zero (thus only + requiring the "mean" values as NN output). + + Note: entropy is always zero, ang logp and kl are not implemented. + + Example:: + + >>> m = TorchDeterministic(loc=torch.tensor([0.0, 0.0])) + >>> m.sample(sample_shape=(2,)) + tensor([[ 0.0, 0.0], [ 0.0, 0.0]]) + + Args: + loc: the determinsitic value to return + """ + + def __init__(self, loc: torch.Tensor) -> None: + super().__init__() + self.loc = loc + + @override(Distribution) + def sample( + self, + *, + sample_shape: Tuple[int, ...] = None, + return_logp: bool = False, + **kwargs, + ) -> Union[TensorType, Tuple[TensorType, TensorType]]: + if return_logp: + raise ValueError(f"Cannot return logp for {self.__class__.__name__}.") + + if sample_shape is None: + sample_shape = torch.Size() + loc_shape = self.loc.shape + return torch.ones(sample_shape + loc_shape, device=self.loc.device) * self.loc + + def rsample( + self, + *, + sample_shape: Tuple[int, ...] = None, + return_logp: bool = False, + **kwargs, + ) -> Union[TensorType, Tuple[TensorType, TensorType]]: + raise NotImplementedError + + @override(Distribution) + def logp(self, value: TensorType, **kwargs) -> TensorType: + raise ValueError(f"Cannot return logp for {self.__class__.__name__}.") + + @override(Distribution) + def entropy(self, **kwargs) -> TensorType: + raise torch.zeros_like(self.loc) + + @override(Distribution) + def kl(self, other: "Distribution", **kwargs) -> TensorType: + raise ValueError(f"Cannot return kl for {self.__class__.__name__}.") + + @staticmethod + @override(Distribution) + def required_model_output_shape( + space: gym.Space, model_config: ModelConfigDict + ) -> Tuple[int, ...]: + # TODO: This was copied from previous code. Is this correct? add unit test. + return tuple(np.prod(space.shape, dtype=np.int32))