Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoids setting jax tracer as lazy property attribute #1843

Merged
merged 8 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import digamma

from numpyro.util import not_jax_tracer

# Parameters for Transformed Rejection with Squeeze (TRS) algorithm - page 3.
_tr_params = namedtuple(
"tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
Expand Down Expand Up @@ -692,7 +694,8 @@ def __get__(self, instance, obj_type=None):
if instance is None:
return self
value = self.wrapped(instance)
setattr(instance, self.wrapped.__name__, value)
if not_jax_tracer(value):
setattr(instance, self.wrapped.__name__, value)
return value


Expand Down
44 changes: 44 additions & 0 deletions test/test_distributions_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import pytest
import scipy

import jax
from jax import lax, random, vmap
import jax.numpy as jnp
from jax.scipy.special import expit, xlog1py, xlogy

import numpyro.distributions as dist
from numpyro.distributions.util import (
add_diag,
binary_cross_entropy_with_logits,
Expand Down Expand Up @@ -182,3 +184,45 @@ def test_add_diag(matrix_shape: tuple, diag_shape: tuple) -> None:
expected = matrix + diag[..., None] * jnp.eye(matrix.shape[-1])
actual = add_diag(matrix, diag)
np.testing.assert_allclose(actual, expected)


@pytest.mark.parametrize(
"my_dist",
[
dist.TruncatedNormal(low=-1.0, high=2.0),
dist.TruncatedCauchy(low=-5, high=10),
dist.TruncatedDistribution(dist.StudentT(3), low=1.5),
],
)
def test_no_tracer_leak_at_lazy_property_log_prob(my_dist):
"""
Tests that truncated distributions, which use @lazy_property
values in their log_prob() methods, do not
have tracer leakage when log_prob() is called.
Reference: https://github.com/pyro-ppl/numpyro/issues/1836, and
https://github.com/CDCgov/multisignal-epi-inference/issues/282
"""
jit_lp = jax.jit(my_dist.log_prob)
with jax.check_tracer_leaks():
jit_lp(1.0)


@pytest.mark.parametrize(
"my_dist",
[
dist.TruncatedNormal(low=-1.0, high=2.0),
dist.TruncatedCauchy(low=-5, high=10),
dist.TruncatedDistribution(dist.StudentT(3), low=1.5),
],
)
def test_no_tracer_leak_at_lazy_property_sample(my_dist):
"""
Tests that truncated distributions, which use @lazy_property
values in their sample() methods, do not
have tracer leakage when sample() is called.
Reference: https://github.com/pyro-ppl/numpyro/issues/1836, and
https://github.com/CDCgov/multisignal-epi-inference/issues/282
"""
jit_sample = jax.jit(my_dist.sample)
with jax.check_tracer_leaks():
jit_sample(jax.random.key(5))
Loading