Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Finish testing matthewearl's Gaussian squashed gaussian PR #13292

Closed
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8e63d3c
Implement GaussianSquashedGaussian. Still buggy
matthewearl Mar 13, 2020
005c524
fix bug in gsg logp
matthewearl Mar 13, 2020
ba69bb7
Fix bugs in KL and entropy methods
matthewearl Mar 13, 2020
113fc4f
Initial attempt at integrating GSG into catalog
matthewearl Mar 13, 2020
c8e53ce
Fix up the shapes returned by SG
matthewearl Mar 14, 2020
f4521f7
Reformatting according to scripts/format.sh
matthewearl Mar 15, 2020
b0c2323
code review markup
matthewearl Apr 14, 2020
0e161fc
Bound loc for numerical stability
matthewearl Apr 14, 2020
511eef6
Merge branch 'master' of github.com:ray-project/ray into me/gsg
matthewearl Apr 16, 2020
86527ec
Merge branch 'me/gsg' of github.com:matthewearl/ray into me/gsg
matthewearl Apr 16, 2020
f226d2e
Fix squashed gaussian unit test
matthewearl Apr 16, 2020
3e1d345
Fix gaussian squashed gaussian following the previous commit
matthewearl Apr 16, 2020
9c9b8bc
add test for gaussian squashed gaussian
matthewearl Apr 16, 2020
731afbd
linter fixes
matthewearl Apr 17, 2020
a80db8b
WIP.
sven1977 Jan 8, 2021
cd9cef2
Merge branch 'master' of https://github.com/ray-project/ray into me/gsg
sven1977 Jan 11, 2021
7e89931
WIP.
sven1977 Jan 11, 2021
9218430
LINT.
sven1977 Jan 11, 2021
ed7d261
Fix.
sven1977 Jan 12, 2021
544b730
Merge branch 'master' of https://github.com/ray-project/ray into gaus…
sven1977 Jan 12, 2021
6098dda
Torch version and LINT.
sven1977 Jan 12, 2021
37f6986
LINT.
sven1977 Jan 12, 2021
32f4201
Fix and LINT.
sven1977 Jan 13, 2021
44d96f9
Merge branch 'master' of https://github.com/ray-project/ray into gaus…
sven1977 Jan 13, 2021
c61739c
wip
sven1977 Jan 13, 2021
4f131af
Merge branch 'master' of https://github.com/ray-project/ray into gaus…
sven1977 Feb 8, 2021
c6319c1
Merge branch 'master' of https://github.com/ray-project/ray into gaus…
sven1977 Apr 11, 2021
ec3b6dc
LINT.
sven1977 Apr 11, 2021
4878362
fix and LINT.
sven1977 Apr 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models.tf.tf_action_dist import Categorical, \
Deterministic, DiagGaussian, Dirichlet, \
GaussianSquashedGaussian, \
MultiActionDistribution, MultiCategorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDeterministic, TorchDiagGaussian, \
Expand Down Expand Up @@ -209,10 +210,16 @@ def get_action_dist(
"Consider reshaping this into a single dimension, "
"using a custom action distribution, "
"using a Tuple action space, or the multi-agent API.")
# TODO(sven): Check for bounds and return SquashedNormal, etc..
if dist_type is None:
dist_cls = TorchDiagGaussian if framework == "torch" \
else DiagGaussian
if framework == "torch":
return TorchDiagGaussian
elif np.any(action_space.bounded_below &
action_space.bounded_above):
return ModelCatalog._make_bounded_dist(action_space)
else:
dist = DiagGaussian
#dist_cls = TorchDiagGaussian if framework == "torch" \
# else DiagGaussian
elif dist_type == "deterministic":
dist_cls = TorchDeterministic if framework == "torch" \
else Deterministic
Expand Down Expand Up @@ -710,6 +717,29 @@ def _get_multi_action_distribution(dist_class, action_space, config,
input_lens=input_lens), int(sum(input_lens))
return dist_class

@staticmethod
def _make_bounded_dist(action_space):
    """Build an action-distribution class for a (partially) bounded Box space.

    Each flattened dimension with finite low AND high bounds gets a
    GaussianSquashedGaussian squashed into [low, high]; dimensions with an
    infinite bound fall back to a plain DiagGaussian.

    Args:
        action_space: The (gym.spaces.Box) action space; its `low`/`high`
            arrays define the per-dimension bounds.

    Returns:
        Tuple of (distribution class or partial, required NN output size).
        Each child distribution consumes 2 outputs (mean and log-std).
    """
    child_dists = []

    # Flatten the bounds so every scalar dimension is handled separately.
    low = np.ravel(action_space.low)
    high = np.ravel(action_space.high)

    for l, h in zip(low, high):
        if not np.isinf(l) and not np.isinf(h):
            dist = partial(GaussianSquashedGaussian, low=l, high=h)
        else:
            dist = DiagGaussian
        child_dists.append(dist)

    # Single action dimension: no need to wrap in MultiActionDistribution.
    # (Return the child explicitly rather than relying on the leaked loop
    # variable `dist`.)
    if len(child_dists) == 1:
        return child_dists[0], 2

    return partial(
        MultiActionDistribution,
        action_space=action_space,
        child_distributions=child_dists,
        input_lens=[2] * len(child_dists)), 2 * len(child_dists)

@staticmethod
def _validate_config(config: ModelConfigDict, framework: str) -> None:
"""Validates a given model config dict.
Expand Down
35 changes: 32 additions & 3 deletions rllib/models/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from ray.rllib.models.jax.jax_action_dist import JAXCategorical
from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \
DiagGaussian, GumbelSoftmax, MultiActionDistribution, MultiCategorical, \
SquashedGaussian
DiagGaussian, GaussianSquashedGaussian, GumbelSoftmax, \
MultiActionDistribution, MultiCategorical, SquashedGaussian
from ray.rllib.models.torch.torch_action_dist import TorchBeta, \
TorchCategorical, TorchDiagGaussian, TorchMultiActionDistribution, \
TorchMultiCategorical, TorchSquashedGaussian
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_squashed_gaussian(self):
check(np.sum(sampled_action_logp), np.sum(log_prob), rtol=0.05)

# NN output.
means = np.array([[0.1, 0.2, 0.3, 0.4, 50.0],
means = np.array([[0.1, 0.2, 0.3, 0.4, 2.9],
[-0.1, -0.2, -0.3, -0.4, -1.0]])
log_stds = np.array([[0.8, -0.2, 0.3, -1.0, 2.0],
[0.7, -0.3, 0.4, -0.9, 2.0]])
Expand Down Expand Up @@ -372,6 +372,35 @@ def test_diag_gaussian(self):
outs = sess.run(outs)
check(outs, log_prob, decimals=4)

def test_gaussian_squashed_gaussian(self):
    for fw, sess in framework_iterator(frameworks="tf", session=True):
        # NN outputs per row: [mean0, mean1, log_std0, log_std1].
        nn_out_a = tf.constant([[-0.5, 0.2, np.log(0.1), np.log(0.5)],
                                [0.6, 0.8, np.log(0.7), np.log(0.8)],
                                [-10.0, 1.2, np.log(0.9), np.log(1.0)]])
        nn_out_b = tf.constant([[0.2, 0.3, np.log(0.2), np.log(0.4)],
                                [0.6, 0.8, np.log(0.7), np.log(0.8)],
                                [-11.0, 1.2, np.log(0.9), np.log(1.0)]])

        dist_a = GaussianSquashedGaussian(nn_out_a, None)
        dist_b = GaussianSquashedGaussian(nn_out_b, None)

        # KL, entropy, and logp values have been verified empirically.
        check(sess.run(dist_a.kl(dist_b)),
              np.array([6.532504, 0., 0.]))
        check(sess.run(dist_a.entropy()),
              np.array([-0.74827796, 0.7070056, -4.971432]))
        sample_pt = tf.constant([[-0.3939393939393939]])
        check(sess.run(dist_a.logp(sample_pt)),
              np.array([0.736003, -3.1547096, -6.5595593]))

        # Deterministic samples are just the squashed distribution means.
        # Verified using _unsquash (which was itself verified as part of
        # the logp test).
        expected_means = np.array([[-0.41861248, 0.1745522],
                                   [0.49179232, 0.62231755],
                                   [-0.99906087, 0.81425166]])
        check(sess.run(dist_a.deterministic_sample()), expected_means)

def test_beta(self):
input_space = Box(-2.0, 1.0, shape=(2000, 10))
low, high = -1.0, 2.0
Expand Down
149 changes: 130 additions & 19 deletions rllib/models/tf/tf_action_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,8 @@ def required_model_output_shape(
return np.prod(action_space.shape) * 2


class SquashedGaussian(TFActionDistribution):
"""A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

The distribution will never return low or high exactly, but
`low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
"""
class _SquashedGaussianBase(TFActionDistribution):
"""A diagonal gaussian distribution, squashed into bounded support."""

def __init__(self,
inputs: List[TensorType],
Expand All @@ -300,11 +296,17 @@ def __init__(self,
"""
assert tfp is not None
mean, log_std = tf.split(inputs, 2, axis=-1)
# Clip `scale` values (coming from NN) to reasonable values.
log_std = tf.clip_by_value(log_std, MIN_LOG_NN_OUTPUT,
MAX_LOG_NN_OUTPUT)
std = tf.exp(log_std)
self._num_vars = mean.shape[1]
assert log_std.shape[1] == self._num_vars
# Clip `std` values (coming from NN) to reasonable values.
self.log_std = tf.clip_by_value(log_std, MIN_LOG_NN_OUTPUT,
MAX_LOG_NN_OUTPUT)
# Clip loc too, for numerical stability reasons.
mean = tf.clip_by_value(mean, -3, 3)
std = tf.exp(self.log_std)
self.distr = tfp.distributions.Normal(loc=mean, scale=std)
assert len(self.distr.loc.shape) == 2
assert len(self.distr.scale.shape) == 2
assert np.all(np.less(low, high))
self.low = low
self.high = high
Expand All @@ -313,27 +315,79 @@ def __init__(self,
@override(ActionDistribution)
def deterministic_sample(self) -> TensorType:
    """Return the squashed mean of the underlying Gaussian.

    Returns:
        The squashed means, shape [batch, num_vars] (asserted below).
    """
    mean = self.distr.mean()
    assert len(mean.shape) == 2
    s = self._squash(mean)
    assert len(s.shape) == 2
    return s

@override(ActionDistribution)
def logp(self, x: TensorType) -> TensorType:
    """Log-density of squashed values `x` under this distribution.

    Uses the change-of-variables formula:
    logp(x) = logp_gaussian(unsquash(x)) - log|d squash/d u|,
    summed over the action dimensions.
    """
    # Unsquash values (from [low,high] to ]-inf,inf[)
    assert len(x.shape) >= 2, "First dim batch, second dim variable"
    unsquashed_values = self._unsquash(x)
    # Get log prob of unsquashed values from our Normal.
    log_prob_gaussian = self.distr.log_prob(unsquashed_values)
    # For safety reasons, clamp somehow, only then sum up.
    log_prob_gaussian = tf.clip_by_value(log_prob_gaussian, -100, 100)
    # Subtract the per-dimension squash-gradient correction FIRST, then
    # reduce over the variable axis once. (Reducing log_prob_gaussian
    # before the subtraction would collapse the variable axis twice and
    # break the shapes.)
    return tf.reduce_sum(
        log_prob_gaussian - self._log_squash_grad(unsquashed_values),
        axis=1)

@override(TFActionDistribution)
def _build_sample_op(self):
    # Draw a raw sample from the underlying Normal, then map it into the
    # bounded (low, high) support.
    raw_sample = self.distr.sample()
    squashed = self._squash(raw_sample)
    assert len(squashed.shape) == 2
    return squashed

def _squash(self, unsquashed_values):
    """Squash an array element-wise into the (low, high) range.

    Abstract: subclasses implement the concrete squashing function.

    Arguments:
        unsquashed_values: values to be squashed

    Returns:
        The squashed values. The output shape is `unsquashed_values.shape`

    """
    raise NotImplementedError

def _unsquash(self, values):
    """Unsquash an array element-wise from the (low, high) range.

    Abstract: inverse of `_squash`; subclasses implement it.

    Arguments:
        values: the squashed values to be mapped back to the unbounded
            domain

    Returns:
        The unsquashed values. The output shape is `values.shape`

    """
    raise NotImplementedError

def _log_squash_grad(self, unsquashed_values):
    """Log of the gradient of `_squash` with respect to its argument.

    Abstract: subclasses implement it; used by `logp` for the
    change-of-variables correction.

    Arguments:
        unsquashed_values: Point(s) at which to measure the log-gradient.

    Returns:
        The log-gradient at the given points. The output shape is
        `unsquashed_values.shape`.

    """
    raise NotImplementedError


class SquashedGaussian(_SquashedGaussianBase):
"""A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

The distribution will never return low or high exactly, but
`low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
"""

def _log_squash_grad(self, unsquashed_values):
    """Log-gradient of the tanh squash at `unsquashed_values`.

    d/du tanh(u) = 1 - tanh(u)^2; SMALL_NUMBER keeps the log finite when
    tanh saturates at +/-1.
    """
    # NOTE: The original span referenced an undefined `log_prob_gaussian`
    # and had an unreachable second return; this is the cleaned-up body.
    unsquashed_values_tanhd = tf.math.tanh(unsquashed_values)
    return tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER)

@override(ActionDistribution)
def entropy(self) -> TensorType:
Expand Down Expand Up @@ -422,6 +476,63 @@ def required_model_output_shape(
return np.prod(action_space.shape) * 2


class GaussianSquashedGaussian(_SquashedGaussianBase):
    """A gaussian CDF-squashed Gaussian distribution.

    The underlying Normal is squashed into (low, high) via the CDF of
    N(0, _SCALE), rescaled to the interval width.

    The distribution will never return low or high exactly, but
    `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
    """
    # Chosen to match the standard logistic variance, so that:
    # Var(N(0, 2 * _SCALE)) = Var(Logistic(0, 1))
    _SCALE = 0.5 * 1.8137

    @override(ActionDistribution)
    def kl(self, other):
        # KL(self || other) is just the KL of the two unsquashed
        # distributions (the squashing map is shared and invertible).
        assert isinstance(other, GaussianSquashedGaussian)

        mean = self.distr.loc
        std = self.distr.scale

        other_mean = other.distr.loc
        other_std = other.distr.scale

        return tf.reduce_sum((other.log_std - self.log_std +
                              (tf.square(std) + tf.square(mean - other_mean)) /
                              (2.0 * tf.square(other_std)) - 0.5), axis=1)

    def entropy(self):
        # Entropy is:
        # -KL(self.distr || N(0, _SCALE)) + log(high - low)
        # where the latter distribution's CDF is used to do the squashing.

        mean = self.distr.loc
        std = self.distr.scale

        # Use tf.math.log (TF2-compatible), consistent with the rest of
        # this file; the TF1-only alias tf.log was removed in TF2.
        return tf.reduce_sum(tf.math.log(self.high - self.low) -
                             (tf.math.log(self._SCALE) - self.log_std +
                              (tf.square(std) + tf.square(mean)) /
                              (2.0 * tf.square(self._SCALE)) - 0.5), axis=1)

    def _log_squash_grad(self, unsquashed_values):
        # Gradient of the CDF squash is the Normal pdf, rescaled by the
        # width of the target interval.
        squash_dist = tfp.distributions.Normal(loc=0, scale=self._SCALE)
        log_grad = squash_dist.log_prob(value=unsquashed_values)
        log_grad += tf.math.log(self.high - self.low)
        return log_grad

    def _squash(self, raw_values):
        # Map through the Normal CDF, then clip away from exactly 0.0/1.0
        # (which would lead to +/-inf log-probs) before rescaling into
        # (low, high).
        values = tfp.bijectors.NormalCDF().forward(raw_values / self._SCALE)
        return (tf.clip_by_value(values, SMALL_NUMBER, 1.0 - SMALL_NUMBER) *
                (self.high - self.low) + self.low)

    def _unsquash(self, values):
        # Inverse of _squash: rescale to (0, 1), then apply the inverse CDF.
        return self._SCALE * tfp.bijectors.NormalCDF().inverse(
            (values - self.low) / (self.high - self.low))


class Deterministic(TFActionDistribution):
"""Action distribution that returns the input values directly.

Expand Down