diff --git a/tensorflow_probability/python/bijectors/bijector.py b/tensorflow_probability/python/bijectors/bijector.py index 4de390e18c..46106fd288 100644 --- a/tensorflow_probability/python/bijectors/bijector.py +++ b/tensorflow_probability/python/bijectors/bijector.py @@ -511,6 +511,7 @@ def __init__(self, name = name_util.strip_invalid_chars(name) super(Bijector, self).__init__(name=name) self._name = name + # TODO(b/176242804): Infer `parameters` if not specified by the child class. self._parameters = self._no_dependency(parameters) self._graph_parents = self._no_dependency(graph_parents or []) @@ -648,6 +649,8 @@ def parameters(self): # Remove "self", "__class__", or other special variables. These can appear # if the subclass used: # `parameters = dict(locals())`. + if self._parameters is None: + return None return {k: v for k, v in self._parameters.items() if not k.startswith('__') and k != 'self'} @@ -689,6 +692,11 @@ def __eq__(self, other): return True def _get_parameterization(self): + if self.parameters is None: + # If a user-written bijector doesn't specify `parameters`, we must assume + # that all instances are unique. + # TODO(b/176242804): this can be removed if we always infer `parameters`. + return id(self) return self.parameters def __call__(self, value, name=None, **kwargs): diff --git a/tensorflow_probability/python/bijectors/bijector_test.py b/tensorflow_probability/python/bijectors/bijector_test.py index ab48289fef..c982da4655 100644 --- a/tensorflow_probability/python/bijectors/bijector_test.py +++ b/tensorflow_probability/python/bijectors/bijector_test.py @@ -280,6 +280,24 @@ def _get_parameterization(self): return id(self) +class UnspecifiedParameters(tfb.Bijector): + """A bijector that fails to pass `parameters` to the base class.""" + + def __init__(self, loc): + self._loc = loc + super(UnspecifiedParameters, self).__init__( + validate_args=False, + is_constant_jacobian=True, + forward_min_event_ndims=0, + name='unspecified_parameters') + + def _forward(self, x): + return x + self._loc + + def _forward_log_det_jacobian(self, x): + return tf.constant(0., x.dtype) + + @test_util.test_all_tf_execution_regimes class BijectorTestEventNdims(test_util.TestCase): @@ -440,6 +458,18 @@ def testUniqueCacheKey(self): self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1) self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1) + def testBijectorsWithUnspecifiedParametersDoNotShareCache(self): + bijector_1 = UnspecifiedParameters(loc=tf.constant(1., dtype=tf.float32)) + bijector_2 = UnspecifiedParameters(loc=tf.constant(2., dtype=tf.float32)) + + x = tf.constant(3., dtype=tf.float32) + y_1 = bijector_1.forward(x) + y_2 = bijector_2.forward(x) + + self.assertIsNot(y_1, y_2) + self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1) + self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1) + def testInstanceCache(self): instance_cache_bijector = tfb.Exp() instance_cache_bijector._cache = cache_util.BijectorCache( diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py index e18bcad27d..75fe14f9f7 100644 --- a/tensorflow_probability/python/version.py +++ b/tensorflow_probability/python/version.py @@ -17,7 +17,7 @@ # We follow Semantic Versioning (https://semver.org/) _MAJOR_VERSION = '0' _MINOR_VERSION = '12' -_PATCH_VERSION = '0' +_PATCH_VERSION = '1' # When building releases, we can update this value on the release branch to # reflect the current release candidate ('rc0', 'rc1') or, finally, the official