Skip to content

Commit

Permalink
Add sign for bijective scalar transforms and generic cdf/icdf i…
Browse files Browse the repository at this point in the history
…mplementation for `TransformedDistribution`s. (#1853)
  • Loading branch information
tillahoffmann authored Aug 26, 2024
1 parent e0d450b commit d52209c
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 18 deletions.
15 changes: 0 additions & 15 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,6 @@ def variance(self):
a = (self.rate / (self.concentration - 1)) ** 2 / (self.concentration - 2)
return jnp.where(self.concentration <= 2, jnp.inf, a)

def cdf(self, x):
return 1 - self.base_dist.cdf(1 / x)

def entropy(self):
return (
self.concentration
Expand Down Expand Up @@ -1205,9 +1202,6 @@ def mean(self):
def variance(self):
return (jnp.exp(self.scale**2) - 1) * jnp.exp(2 * self.loc + self.scale**2)

def cdf(self, x):
return self.base_dist.cdf(jnp.log(x))

def entropy(self):
return (1 + jnp.log(2 * jnp.pi)) / 2 + self.loc + jnp.log(self.scale)

Expand Down Expand Up @@ -1283,9 +1277,6 @@ def variance(self):
- self.mean**2
)

def cdf(self, x):
return self.base_dist.cdf(jnp.log(x))

def entropy(self):
log_low = jnp.log(self.low)
log_high = jnp.log(self.high)
Expand Down Expand Up @@ -2162,12 +2153,6 @@ def variance(self):
def support(self):
return constraints.greater_than(self.scale)

def cdf(self, value):
return 1 - jnp.power(self.scale / value, self.alpha)

def icdf(self, q):
return self.scale / jnp.power(1 - q, 1 / self.alpha)

def entropy(self):
return jnp.log(self.scale / self.alpha) + 1 + 1 / self.alpha

Expand Down
18 changes: 18 additions & 0 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,24 @@ def mean(self):
def variance(self):
raise NotImplementedError

def cdf(self, value):
sign = 1
for transform in reversed(self.transforms):
sign *= transform.sign
value = transform.inv(value)
q = self.base_dist.cdf(value)
return jnp.where(sign < 0, 1 - q, q)

def icdf(self, q):
sign = 1
for transform in self.transforms:
sign *= transform.sign
q = jnp.where(sign < 0, 1 - q, q)
value = self.base_dist.icdf(q)
for transform in self.transforms:
value = transform(value)
return value


class FoldedDistribution(TransformedDistribution):
"""
Expand Down
35 changes: 35 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ def inverse_shape(self, shape):
"""
return shape

@property
def sign(self):
"""
Sign of the derivative of the transform if it is bijective.
"""
raise NotImplementedError(
f"Transform `{self.__class__.__name__}` does not implement `sign`."
)

# Allow for pickle serialization of transforms.
def __getstate__(self):
attrs = {}
Expand Down Expand Up @@ -147,6 +156,10 @@ def domain(self):
def codomain(self):
return self._inv.domain

@property
def sign(self):
return self._inv.sign

@property
def inv(self):
return self._inv
Expand Down Expand Up @@ -231,6 +244,10 @@ def codomain(self):
else:
raise NotImplementedError

@property
def sign(self):
return jnp.sign(self.scale)

def __call__(self, x):
return self.loc + self.scale * x

Expand Down Expand Up @@ -309,6 +326,13 @@ def codomain(self):
self.parts[-1].codomain, output_event_dim - last_output_event_dim
)

@property
def sign(self):
sign = 1
for transform in self.parts:
sign *= transform.sign
return sign

def __call__(self, x):
for part in self.parts:
x = part(x)
Expand Down Expand Up @@ -509,6 +533,8 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):


class ExpTransform(Transform):
sign = 1

# TODO: refine domain/codomain logic through setters, especially when
# transforms for inverses are supported
def __init__(self, domain=constraints.real):
Expand Down Expand Up @@ -550,6 +576,8 @@ def __eq__(self, other):


class IdentityTransform(ParameterFreeTransform):
sign = 1

def __call__(self, x):
return x

Expand Down Expand Up @@ -912,9 +940,14 @@ def __eq__(self, other):
return False
return jnp.array_equal(self.exponent, other.exponent)

@property
def sign(self):
return jnp.sign(self.exponent)


class SigmoidTransform(ParameterFreeTransform):
codomain = constraints.unit_interval
sign = 1

def __call__(self, x):
return _clipped_expit(x)
Expand Down Expand Up @@ -1006,6 +1039,7 @@ class SoftplusTransform(ParameterFreeTransform):

domain = constraints.real
codomain = constraints.softplus_positive
sign = 1

def __call__(self, x):
return softplus(x)
Expand Down Expand Up @@ -1177,6 +1211,7 @@ class ReshapeTransform(Transform):

domain = constraints.real
codomain = constraints.real
sign = 1

def __init__(self, forward_shape, inverse_shape) -> None:
forward_size = math.prod(forward_shape)
Expand Down
7 changes: 4 additions & 3 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def get_sp_dist(jax_dist):
T(dist.HalfNormal, 1.0),
T(dist.HalfNormal, np.array([1.0, 2.0])),
T(_ImproperWrapper, constraints.positive, (), (3,)),
T(dist.InverseGamma, np.array([3.1]), np.array([[2.0], [3.0]])),
T(dist.InverseGamma, np.array([1.7]), np.array([[2.0], [3.0]])),
T(dist.InverseGamma, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
T(dist.Kumaraswamy, 10.0, np.array([2.0, 3.0])),
Expand Down Expand Up @@ -1568,7 +1569,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
samples = d.sample(key=random.PRNGKey(0), sample_shape=(100,))
quantiles = random.uniform(random.PRNGKey(1), (100,) + d.shape())
try:
rtol = 2e-3 if jax_dist in (dist.Gamma, dist.StudentT) else 1e-5
rtol = 2e-3 if jax_dist in (dist.Gamma, dist.LogNormal, dist.StudentT) else 1e-5
if d.shape() == () and not d.is_discrete:
assert_allclose(
jax.vmap(jax.grad(d.cdf))(samples),
Expand All @@ -1585,7 +1586,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
assert_allclose(d.cdf(d.icdf(quantiles)), quantiles, atol=1e-5, rtol=1e-5)
assert_allclose(d.icdf(d.cdf(samples)), samples, atol=1e-5, rtol=rtol)
except NotImplementedError:
pass
pytest.skip("cdf/icdf not implemented")

# test against scipy
if not sp_dist:
Expand All @@ -1599,7 +1600,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
expected_icdf = sp_dist.ppf(quantiles)
assert_allclose(actual_icdf, expected_icdf, atol=1e-4, rtol=1e-4)
except NotImplementedError:
pass
pytest.skip("cdf/icdf not implemented")


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DIRECTIONAL)
Expand Down
3 changes: 3 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def test_bijective_transforms(transform, shape):
)
slogdet = jnp.linalg.slogdet(jac)
assert jnp.allclose(log_abs_det_jacobian, slogdet.logabsdet, atol=atol)
assert transform.domain.event_dim or jnp.allclose(
jnp.sign(jnp.diagonal(jac, axis1=-1, axis2=-2)), transform.sign
)


def test_batched_recursive_linear_transform():
Expand Down

0 comments on commit d52209c

Please sign in to comment.