Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gaussian Squashed Gaussian #7609

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2
from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical, \
Deterministic, DiagGaussian, MultiActionDistribution, Dirichlet
Deterministic, DiagGaussian, MultiActionDistribution, Dirichlet, \
GaussianSquashedGaussian
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.tf.lstm_v1 import LSTM
Expand Down Expand Up @@ -104,6 +105,29 @@ class ModelCatalog:
>>> action = dist.sample()
"""

@staticmethod
def _make_bounded_dist(action_space):
    """Build an action distribution for a (possibly partially bounded)
    Box space.

    Each flattened dimension with finite lower AND upper bounds gets a
    GaussianSquashedGaussian squashed into that interval; dimensions with
    any infinite bound fall back to an unbounded DiagGaussian.

    Arguments:
        action_space: A gym Box space with `low` and `high` arrays.

    Returns:
        Tuple of (distribution class/partial, required model output size).
        Every child distribution consumes 2 inputs (mean and log-std).
    """
    lows = np.ravel(action_space.low)
    highs = np.ravel(action_space.high)

    child_dists = [
        DiagGaussian if (np.isinf(lo) or np.isinf(hi)) else partial(
            GaussianSquashedGaussian, low=lo, high=hi)
        for lo, hi in zip(lows, highs)
    ]

    # Single action dimension: no need to wrap in a multi-action dist.
    if len(child_dists) == 1:
        return child_dists[0], 2

    return partial(
        MultiActionDistribution,
        action_space=action_space,
        child_distributions=child_dists,
        input_lens=[2] * len(child_dists)), 2 * len(child_dists)

@staticmethod
@DeveloperAPI
def get_action_dist(action_space,
Expand Down Expand Up @@ -147,9 +171,16 @@ def get_action_dist(action_space,
"Consider reshaping this into a single dimension, "
"using a custom action distribution, "
"using a Tuple action space, or the multi-agent API.")
# TODO(sven): Check for bounds and return SquashedNormal, etc..
if dist_type is None:
dist = DiagGaussian if framework == "tf" else TorchDiagGaussian
any_bounded = np.any(
action_space.bounded_below & action_space.bounded_above)
if framework != "tf":
return TorchDiagGaussian
elif np.any(action_space.bounded_below &
action_space.bounded_above):
return ModelCatalog._make_bounded_dist(action_space)
else:
dist = TorchDiagGaussian
matthewearl marked this conversation as resolved.
Show resolved Hide resolved
elif dist_type == "deterministic":
dist = Deterministic
# Discrete Space -> Categorical.
Expand Down
159 changes: 132 additions & 27 deletions rllib/models/tf/tf_action_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,8 @@ def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape) * 2


class SquashedGaussian(TFActionDistribution):
"""A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

The distribution will never return low or high exactly, but
`low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
"""
class _SquashedGaussianBase(TFActionDistribution):
"""A univariate gaussian distribution, squashed into bounded support."""

def __init__(self, inputs, model, low=-1.0, high=1.0):
"""Parameterizes the distribution via `inputs`.
Expand All @@ -207,48 +203,100 @@ def __init__(self, inputs, model, low=-1.0, high=1.0):
(excluding this value).
"""
assert tfp is not None
loc, log_scale = tf.split(inputs, 2, axis=-1)
loc, log_scale = inputs[:, 0], inputs[:, 1]
# Clip `scale` values (coming from NN) to reasonable values.
log_scale = tf.clip_by_value(log_scale, MIN_LOG_NN_OUTPUT,
MAX_LOG_NN_OUTPUT)
scale = tf.exp(log_scale)
self.log_std = tf.clip_by_value(log_scale, MIN_LOG_NN_OUTPUT,
MAX_LOG_NN_OUTPUT)
scale = tf.exp(self.log_std)
self.distr = tfp.distributions.Normal(loc=loc, scale=scale)
assert len(self.distr.loc.shape) == 1
assert len(self.distr.scale.shape) == 1
assert np.all(np.less(low, high))
self.low = low
self.high = high
super().__init__(inputs, model)

@override(TFActionDistribution)
def sampled_action_logp(self):
unsquashed_values = self._unsquash(self.sample_op)
log_prob = tf.reduce_sum(
self.distr.log_prob(unsquashed_values), axis=-1)
unsquashed_values_tanhd = tf.math.tanh(unsquashed_values)
log_prob -= tf.math.reduce_sum(
tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER),
axis=-1)
return log_prob

@override(ActionDistribution)
def deterministic_sample(self):
mean = self.distr.mean()
return self._squash(mean)
assert len(mean.shape) == 1, "Shape should be batch dim only"
s = self._squash(mean)
assert len(s.shape) == 1
return s[:, None]

@override(ActionDistribution)
def logp(self, x):
assert len(x.shape) >= 2, "First dim batch, second dim variable"
unsquashed_values = self._unsquash(x[:, 0])
log_prob = self.distr.log_prob(value=unsquashed_values)
return log_prob - self._log_squash_grad(unsquashed_values)

@override(TFActionDistribution)
def _build_sample_op(self):
return self._squash(self.distr.sample())
s = self._squash(self.distr.sample())
assert len(s.shape) == 1
return s[:, None]

@override(ActionDistribution)
def logp(self, x):
unsquashed_values = self._unsquash(x)
def _squash(self, unsquashed_values):
"""Squash an array element-wise into the (high, low) range

Arguments:
unsquashed_values: values to be squashed

Returns:
The squashed values. The output shape is `unsquashed_values.shape`

"""
raise NotImplementedError

def _unsquash(self, values):
"""Unsquash an array element-wise from the (high, low) range

Arguments:
squashed_values: values to be unsquashed

Returns:
The unsquashed values. The output shape is `squashed_values.shape`

"""
raise NotImplementedError

def _log_squash_grad(self, unsquashed_values):
"""Log gradient of _squash with respect to its argument.

Arguments:
squashed_values: Point at which to measure the gradient.

Returns:
The gradient at the given point. The output shape is
`squashed_values.shape`.

"""
raise NotImplementedError


class SquashedGaussian(_SquashedGaussianBase):
"""A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

The distribution will never return low or high exactly, but
`low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
"""

@override(TFActionDistribution)
def sampled_action_logp(self):
unsquashed_values = self._unsquash(self.sample_op)
log_prob = tf.reduce_sum(
self.distr.log_prob(value=unsquashed_values), axis=-1)
self.distr.log_prob(unsquashed_values), axis=-1)
unsquashed_values_tanhd = tf.math.tanh(unsquashed_values)
log_prob -= tf.math.reduce_sum(
tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER),
axis=-1)
return log_prob

def _log_squash_grad(self, unsquashed_values):
unsquashed_values_tanhd = tf.math.tanh(unsquashed_values)
return tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER)

def _squash(self, raw_values):
# Make sure raw_values are not too high/low (such that tanh would
# return exactly 1.0/-1.0, which would lead to +/-inf log-probs).
Expand All @@ -263,6 +311,63 @@ def _unsquash(self, values):
(self.high - self.low) * 2.0 - 1.0)


class GaussianSquashedGaussian(_SquashedGaussianBase):
    """A gaussian CDF-squashed Gaussian distribution.

    Samples from the underlying Normal are pushed through the CDF of
    N(0, _SCALE) and rescaled into the (low, high) interval.

    The distribution will never return low or high exactly, but
    `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
    """
    # Chosen to match the standard logistic variance, so that:
    # Var(N(0, 2 * _SCALE)) = Var(Logistic(0, 1))
    _SCALE = 0.5 * 1.8137

    @override(ActionDistribution)
    def kl(self, other):
        # KL(self || other) is just the KL of the two unsquashed
        # distributions (the squashing bijection is shared and cancels).
        assert isinstance(other, GaussianSquashedGaussian)

        mean = self.distr.loc
        std = self.distr.scale

        other_mean = other.distr.loc
        other_std = other.distr.scale

        # Closed-form KL between two univariate Gaussians:
        # log(s2/s1) + (s1^2 + (m1-m2)^2) / (2*s2^2) - 1/2.
        return (other.log_std - self.log_std +
                (tf.square(std) + tf.square(mean - other_mean)) /
                (2.0 * tf.square(other_std)) - 0.5)

    def entropy(self):
        # Entropy is:
        # -KL(self.distr || N(0, _SCALE)) + log(high - low)
        # where the latter distribution's CDF is used to do the squashing.

        mean = self.distr.loc
        std = self.distr.scale

        # NOTE: tf.math.log (not the removed-in-TF2 tf.log), consistent
        # with the rest of this module.
        return (tf.math.log(self.high - self.low) -
                (tf.math.log(self._SCALE) - self.log_std +
                 (tf.square(std) + tf.square(mean)) /
                 (2.0 * tf.square(self._SCALE)) - 0.5))

    def _log_squash_grad(self, unsquashed_values):
        # d/dx [(high - low) * CDF_{N(0, _SCALE)}(x) + low]
        #   = (high - low) * pdf_{N(0, _SCALE)}(x); return its log.
        squash_dist = tfp.distributions.Normal(loc=0, scale=self._SCALE)
        log_grad = squash_dist.log_prob(value=unsquashed_values)
        log_grad += tf.math.log(self.high - self.low)
        return log_grad

    def _squash(self, raw_values):
        # Map through the squashing CDF, then clip slightly inside (0, 1)
        # so that log-probs at the interval boundaries stay finite.
        values = tfp.bijectors.NormalCDF().forward(raw_values / self._SCALE)
        return (tf.clip_by_value(values, SMALL_NUMBER, 1.0 - SMALL_NUMBER) *
                (self.high - self.low) + self.low)

    def _unsquash(self, values):
        # Inverse of _squash (the boundary clipping is not inverted).
        return self._SCALE * tfp.bijectors.NormalCDF().inverse(
            (values - self.low) / (self.high - self.low))


class Deterministic(TFActionDistribution):
"""Action distribution that returns the input values directly.

Expand Down