Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic cdf and icdf implementation for scalar TransformedDistributions. #1853

Merged
merged 1 commit into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,6 @@ def variance(self):
a = (self.rate / (self.concentration - 1)) ** 2 / (self.concentration - 2)
return jnp.where(self.concentration <= 2, jnp.inf, a)

def cdf(self, x):
return 1 - self.base_dist.cdf(1 / x)

def entropy(self):
return (
self.concentration
Expand Down Expand Up @@ -1205,9 +1202,6 @@ def mean(self):
def variance(self):
return (jnp.exp(self.scale**2) - 1) * jnp.exp(2 * self.loc + self.scale**2)

def cdf(self, x):
return self.base_dist.cdf(jnp.log(x))

def entropy(self):
return (1 + jnp.log(2 * jnp.pi)) / 2 + self.loc + jnp.log(self.scale)

Expand Down Expand Up @@ -1283,9 +1277,6 @@ def variance(self):
- self.mean**2
)

def cdf(self, x):
return self.base_dist.cdf(jnp.log(x))

def entropy(self):
log_low = jnp.log(self.low)
log_high = jnp.log(self.high)
Expand Down Expand Up @@ -2162,12 +2153,6 @@ def variance(self):
def support(self):
return constraints.greater_than(self.scale)

def cdf(self, value):
return 1 - jnp.power(self.scale / value, self.alpha)

def icdf(self, q):
return self.scale / jnp.power(1 - q, 1 / self.alpha)

def entropy(self):
return jnp.log(self.scale / self.alpha) + 1 + 1 / self.alpha

Expand Down
18 changes: 18 additions & 0 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,24 @@ def mean(self):
def variance(self):
raise NotImplementedError

def cdf(self, value):
sign = 1
for transform in reversed(self.transforms):
sign *= transform.sign
value = transform.inv(value)
q = self.base_dist.cdf(value)
return jnp.where(sign < 0, 1 - q, q)

def icdf(self, q):
sign = 1
for transform in self.transforms:
sign *= transform.sign
q = jnp.where(sign < 0, 1 - q, q)
value = self.base_dist.icdf(q)
for transform in self.transforms:
value = transform(value)
return value


class FoldedDistribution(TransformedDistribution):
"""
Expand Down
35 changes: 35 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ def inverse_shape(self, shape):
"""
return shape

@property
def sign(self):
"""
Sign of the derivative of the transform if it is bijective.
"""
raise NotImplementedError(
f"Transform `{self.__class__.__name__}` does not implement `sign`."
)

# Allow for pickle serialization of transforms.
def __getstate__(self):
attrs = {}
Expand Down Expand Up @@ -147,6 +156,10 @@ def domain(self):
def codomain(self):
return self._inv.domain

@property
def sign(self):
return self._inv.sign

@property
def inv(self):
return self._inv
Expand Down Expand Up @@ -231,6 +244,10 @@ def codomain(self):
else:
raise NotImplementedError

@property
def sign(self):
return jnp.sign(self.scale)

def __call__(self, x):
return self.loc + self.scale * x

Expand Down Expand Up @@ -309,6 +326,13 @@ def codomain(self):
self.parts[-1].codomain, output_event_dim - last_output_event_dim
)

@property
def sign(self):
sign = 1
for transform in self.parts:
sign *= transform.sign
return sign

def __call__(self, x):
for part in self.parts:
x = part(x)
Expand Down Expand Up @@ -509,6 +533,8 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):


class ExpTransform(Transform):
sign = 1

# TODO: refine domain/codomain logic through setters, especially when
# transforms for inverses are supported
def __init__(self, domain=constraints.real):
Expand Down Expand Up @@ -550,6 +576,8 @@ def __eq__(self, other):


class IdentityTransform(ParameterFreeTransform):
sign = 1

def __call__(self, x):
return x

Expand Down Expand Up @@ -912,9 +940,14 @@ def __eq__(self, other):
return False
return jnp.array_equal(self.exponent, other.exponent)

@property
def sign(self):
return jnp.sign(self.exponent)


class SigmoidTransform(ParameterFreeTransform):
codomain = constraints.unit_interval
sign = 1

def __call__(self, x):
return _clipped_expit(x)
Expand Down Expand Up @@ -1006,6 +1039,7 @@ class SoftplusTransform(ParameterFreeTransform):

domain = constraints.real
codomain = constraints.softplus_positive
sign = 1

def __call__(self, x):
return softplus(x)
Expand Down Expand Up @@ -1177,6 +1211,7 @@ class ReshapeTransform(Transform):

domain = constraints.real
codomain = constraints.real
sign = 1

def __init__(self, forward_shape, inverse_shape) -> None:
forward_size = math.prod(forward_shape)
Expand Down
7 changes: 4 additions & 3 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def get_sp_dist(jax_dist):
T(dist.HalfNormal, 1.0),
T(dist.HalfNormal, np.array([1.0, 2.0])),
T(_ImproperWrapper, constraints.positive, (), (3,)),
T(dist.InverseGamma, np.array([3.1]), np.array([[2.0], [3.0]])),
T(dist.InverseGamma, np.array([1.7]), np.array([[2.0], [3.0]])),
T(dist.InverseGamma, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
T(dist.Kumaraswamy, 10.0, np.array([2.0, 3.0])),
Expand Down Expand Up @@ -1568,7 +1569,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
samples = d.sample(key=random.PRNGKey(0), sample_shape=(100,))
quantiles = random.uniform(random.PRNGKey(1), (100,) + d.shape())
try:
rtol = 2e-3 if jax_dist in (dist.Gamma, dist.StudentT) else 1e-5
rtol = 2e-3 if jax_dist in (dist.Gamma, dist.LogNormal, dist.StudentT) else 1e-5
if d.shape() == () and not d.is_discrete:
assert_allclose(
jax.vmap(jax.grad(d.cdf))(samples),
Expand All @@ -1585,7 +1586,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
assert_allclose(d.cdf(d.icdf(quantiles)), quantiles, atol=1e-5, rtol=1e-5)
assert_allclose(d.icdf(d.cdf(samples)), samples, atol=1e-5, rtol=rtol)
except NotImplementedError:
pass
pytest.skip("cdf/icdf not implemented")

# test against scipy
if not sp_dist:
Expand All @@ -1599,7 +1600,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
expected_icdf = sp_dist.ppf(quantiles)
assert_allclose(actual_icdf, expected_icdf, atol=1e-4, rtol=1e-4)
except NotImplementedError:
pass
pytest.skip("cdf/icdf not implemented")


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DIRECTIONAL)
Expand Down
3 changes: 3 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def test_bijective_transforms(transform, shape):
)
slogdet = jnp.linalg.slogdet(jac)
assert jnp.allclose(log_abs_det_jacobian, slogdet.logabsdet, atol=atol)
assert transform.domain.event_dim or jnp.allclose(
jnp.sign(jnp.diagonal(jac, axis1=-1, axis2=-2)), transform.sign
)


def test_batched_recursive_linear_transform():
Expand Down
Loading