Skip to content

Commit

Permalink
Merge pull request #1205 from jburnim/r0.12
Browse files Browse the repository at this point in the history
Prepare branch for TFP 0.12.1 release
  • Loading branch information
jburnim authored Dec 28, 2020
2 parents dcd59ed + c1818a8 commit 43a9d6c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
8 changes: 8 additions & 0 deletions tensorflow_probability/python/bijectors/bijector.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def __init__(self,
name = name_util.strip_invalid_chars(name)
super(Bijector, self).__init__(name=name)
self._name = name
# TODO(b/176242804): Infer `parameters` if not specified by the child class.
self._parameters = self._no_dependency(parameters)

self._graph_parents = self._no_dependency(graph_parents or [])
Expand Down Expand Up @@ -648,6 +649,8 @@ def parameters(self):
# Remove "self", "__class__", or other special variables. These can appear
# if the subclass used:
# `parameters = dict(locals())`.
if self._parameters is None:
return None
return {k: v for k, v in self._parameters.items()
if not k.startswith('__') and k != 'self'}

Expand Down Expand Up @@ -689,6 +692,11 @@ def __eq__(self, other):
return True

def _get_parameterization(self):
if self.parameters is None:
# If a user-written bijector doesn't specify `parameters`, we must assume
# that all instances are unique.
# TODO(b/176242804): this can be removed if we always infer `parameters`.
return id(self)
return self.parameters

def __call__(self, value, name=None, **kwargs):
Expand Down
30 changes: 30 additions & 0 deletions tensorflow_probability/python/bijectors/bijector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,24 @@ def _get_parameterization(self):
return id(self)


class UnspecifiedParameters(tfb.Bijector):
"""A bijector that fails to pass `parameters` to the base class."""

def __init__(self, loc):
self._loc = loc
super(UnspecifiedParameters, self).__init__(
validate_args=False,
is_constant_jacobian=True,
forward_min_event_ndims=0,
name='unspecified_parameters')

def _forward(self, x):
return x + self._loc

def _forward_log_det_jacobian(self, x):
return tf.constant(0., x.dtype)


@test_util.test_all_tf_execution_regimes
class BijectorTestEventNdims(test_util.TestCase):

Expand Down Expand Up @@ -440,6 +458,18 @@ def testUniqueCacheKey(self):
self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1)
self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1)

def testBijectorsWithUnspecifiedParametersDoNotShareCache(self):
bijector_1 = UnspecifiedParameters(loc=tf.constant(1., dtype=tf.float32))
bijector_2 = UnspecifiedParameters(loc=tf.constant(2., dtype=tf.float32))

x = tf.constant(3., dtype=tf.float32)
y_1 = bijector_1.forward(x)
y_2 = bijector_2.forward(x)

self.assertIsNot(y_1, y_2)
self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1)
self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1)

def testInstanceCache(self):
instance_cache_bijector = tfb.Exp()
instance_cache_bijector._cache = cache_util.BijectorCache(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_probability/python/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '0'
_MINOR_VERSION = '12'
_PATCH_VERSION = '0'
_PATCH_VERSION = '1'

# When building releases, we can update this value on the release branch to
# reflect the current release candidate ('rc0', 'rc1') or, finally, the official
Expand Down

0 comments on commit 43a9d6c

Please sign in to comment.