From 711810162be6163ada14bb2c9f1306535ed0f428 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Mon, 21 Aug 2023 15:18:59 +0200 Subject: [PATCH 01/11] dispatch _promote_batch_shape_expanded to Independent --- numpyro/distributions/batch_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index cc13dbf40..4023f3d66 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -53,6 +53,7 @@ ExpandedDistribution, MaskedDistribution, Unit, + Independent, ) from numpyro.distributions.transforms import ( AffineTransform, @@ -524,7 +525,7 @@ def _default_promote_batch_shape(d: Distribution): @promote_batch_shape.register -def _promote_batch_shape_expanded(d: ExpandedDistribution): +def _promote_batch_shape_expanded(d: Union[ExpandedDistribution, Independent]): orig_delta_batch_shape = d.batch_shape[ : len(d.batch_shape) - len(d.base_dist.batch_shape) ] From ed6a716bcbede76b6abb78cf5f79fd72804453d3 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Mon, 21 Aug 2023 15:28:13 +0200 Subject: [PATCH 02/11] formatting --- numpyro/distributions/batch_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 4023f3d66..5dad06128 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -51,9 +51,9 @@ from numpyro.distributions.distribution import ( Distribution, ExpandedDistribution, + Independent, MaskedDistribution, Unit, - Independent, ) from numpyro.distributions.transforms import ( AffineTransform, From ccb3108fd53d3fbee37d2112a09b71f01bdefc78 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Mon, 21 Aug 2023 17:04:51 +0200 Subject: [PATCH 03/11] union type workaround --- numpyro/distributions/batch_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 5dad06128..95ea9211b 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -524,8 +524,8 @@ def _default_promote_batch_shape(d: Distribution): return new_self -@promote_batch_shape.register -def _promote_batch_shape_expanded(d: Union[ExpandedDistribution, Independent]): +@promote_batch_shape.register(Independent) +def _promote_batch_shape_expanded(d: ExpandedDistribution): orig_delta_batch_shape = d.batch_shape[ : len(d.batch_shape) - len(d.base_dist.batch_shape) ] @@ -563,7 +563,6 @@ def _promote_batch_shape_expanded(d: Union[ExpandedDistribution, Independent]): new_self.base_dist = new_base_dist return new_self - @promote_batch_shape.register def _promote_batch_shape_masked(d: MaskedDistribution): new_self = copy.copy(d) From 4735a69c93ee07db007e0367e92612e0fced64c3 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Mon, 21 Aug 2023 17:16:09 +0200 Subject: [PATCH 04/11] linting --- numpyro/distributions/batch_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 95ea9211b..20edf5cf9 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -563,6 +563,7 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution): new_self.base_dist = new_base_dist return new_self + @promote_batch_shape.register def _promote_batch_shape_masked(d: MaskedDistribution): new_self = copy.copy(d) From 273757c68f27f4dce88b181e90e0aac67b44be81 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Wed, 23 Aug 2023 11:27:46 +0200 Subject: [PATCH 05/11] separate register --- numpyro/distributions/batch_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 20edf5cf9..253152470 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -524,7 +524,7 @@ def _default_promote_batch_shape(d: Distribution): return new_self -@promote_batch_shape.register(Independent) +@promote_batch_shape.register def _promote_batch_shape_expanded(d: ExpandedDistribution): orig_delta_batch_shape = d.batch_shape[ : len(d.batch_shape) - len(d.base_dist.batch_shape) @@ -564,6 +564,9 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution): return new_self +promote_batch_shape.register(Independent, _promote_batch_shape_expanded) + + @promote_batch_shape.register def _promote_batch_shape_masked(d: MaskedDistribution): new_self = copy.copy(d) From 44b4b211172333ae3e5ba6fc89fb7e09e2c2c97c Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Thu, 24 Aug 2023 13:26:11 +0200 Subject: [PATCH 06/11] add test of scan plate mask --- test/contrib/test_control_flow.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index c4b2e641d..dc17c3eb5 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -10,7 +10,7 @@ import numpyro from numpyro.contrib.control_flow import cond, scan import numpyro.distributions as dist -from numpyro.handlers import seed, substitute, trace +from numpyro.handlers import seed, substitute, trace, mask from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO from numpyro.infer.util import potential_energy @@ -210,3 +210,21 @@ def transition_fn(c, val): tr = numpyro.handlers.trace(model).get_trace() assert tr["x"]["value"].shape == (10, 1) assert tr["x"]["fn"].log_prob(tr["x"]["value"]).shape == (10, 3) + + +def test_scan_masked(): + def model(y=None, T=10): + def transition(carry, y_curr): + x_prev, t = carry + with numpyro.plate("N", 10, dim=-1): + with mask(mask=(t < T)): + x_curr = numpyro.sample('x', dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1)) + y_curr = numpyro.sample('y', dist.Normal(x_curr, jnp.ones((10, 3))).to_event(1), obs=y_curr) + return (x_curr, t + 1), None + x0 = numpyro.sample('x_0', dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1)) + + x, t = scan(transition, (x0, 0), y, length=T) + return (x, y) + + with numpyro.handlers.seed(rng_seed=0): + x, y = model() From 3060f66d21fef4ec662798b1882320bbc16f2844 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Thu, 24 Aug 2023 13:35:24 +0200 Subject: [PATCH 07/11] lint with flake8 and black --- test/contrib/test_control_flow.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index dc17c3eb5..bb7d4a167 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -212,16 +212,26 @@ def transition_fn(c, val): assert tr["x"]["fn"].log_prob(tr["x"]["value"]).shape == (10, 3) -def test_scan_masked(): +def test_scan_plate_mask(): def model(y=None, T=10): def transition(carry, y_curr): x_prev, t = carry with numpyro.plate("N", 10, dim=-1): with mask(mask=(t < T)): - x_curr = numpyro.sample('x', dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1)) - y_curr = numpyro.sample('y', dist.Normal(x_curr, jnp.ones((10, 3))).to_event(1), obs=y_curr) + x_curr = numpyro.sample( + "x", + dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1), + ) + y_curr = numpyro.sample( + "y", + dist.Normal(x_curr, jnp.ones((10, 3))).to_event(1), + obs=y_curr, + ) return (x_curr, t + 1), None - x0 = numpyro.sample('x_0', dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1)) + + x0 = numpyro.sample( + "x_0", dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1) + ) x, t = scan(transition, (x0, 0), y, length=T) return (x, y) From 01da9f7e56f2c19ed6fb0ea95406c7b43b854bad Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Thu, 24 Aug 2023 13:43:16 +0200 Subject: [PATCH 08/11] isort --- test/contrib/test_control_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index bb7d4a167..243e2ef2c 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -10,7 +10,7 @@ import numpyro from numpyro.contrib.control_flow import cond, scan import numpyro.distributions as dist -from numpyro.handlers import seed, substitute, trace, mask +from numpyro.handlers import mask, seed, substitute, trace from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO from numpyro.infer.util import potential_energy From d283108374e63b44959f01f2706621d90de72b5d Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Fri, 25 Aug 2023 16:27:36 +0200 Subject: [PATCH 09/11] _promote_batch_shape_independent and log_density in test --- numpyro/distributions/batch_util.py | 12 +++++++++--- test/contrib/test_control_flow.py | 5 +++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 253152470..48688ac10 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -564,9 +564,6 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution): return new_self -promote_batch_shape.register(Independent, _promote_batch_shape_expanded) - - @promote_batch_shape.register def _promote_batch_shape_masked(d: MaskedDistribution): new_self = copy.copy(d) @@ -576,6 +573,15 @@ def _promote_batch_shape_masked(d: MaskedDistribution): return new_self +@promote_batch_shape.register +def _promote_batch_shape_independent(d: Independent): + new_self = copy.copy(d) + new_base_dist = promote_batch_shape(d.base_dist) + new_self._batch_shape = new_base_dist.batch_shape[: d.event_dim] + new_self.base_dist = new_base_dist + return new_self + + @promote_batch_shape.register def _promote_batch_shape_unit(d: Unit): return d diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 243e2ef2c..b84a0c717 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -12,7 +12,7 @@ import numpyro.distributions as dist from numpyro.handlers import mask, seed, substitute, trace from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO -from numpyro.infer.util import potential_energy +from numpyro.infer.util import potential_energy, log_density def test_scan(): @@ -237,4 +237,5 @@ def transition(carry, y_curr): return (x, y) with numpyro.handlers.seed(rng_seed=0): - x, y = model() + model_density, model_trace = log_density(model, (None, 10), {}, {}) + assert model_density From 2ececc84ccf9cf1e9627b7044110b3b50e7e81b5 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Fri, 25 Aug 2023 16:30:04 +0200 Subject: [PATCH 10/11] import sorting --- test/contrib/test_control_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index b84a0c717..95c5057f7 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -12,7 +12,7 @@ import numpyro.distributions as dist from numpyro.handlers import mask, seed, substitute, trace from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO -from numpyro.infer.util import potential_energy, log_density +from numpyro.infer.util import log_density, potential_energy def test_scan(): From 58232dfbac9a3aa5a2d397ca04afd2bf153424d5 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Fri, 25 Aug 2023 16:33:58 +0200 Subject: [PATCH 11/11] shape assertions --- test/contrib/test_control_flow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 95c5057f7..6ccc84605 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -239,3 +239,5 @@ def transition(carry, y_curr): with numpyro.handlers.seed(rng_seed=0): model_density, model_trace = log_density(model, (None, 10), {}, {}) assert model_density + assert model_trace["x"]["fn"].batch_shape == (10,) + assert model_trace["x"]["fn"].event_shape == (3,)