From 557bfe00d6cd7a79fa7d97df9869d870e6bc6dc8 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 14 Mar 2024 14:48:26 -0400 Subject: [PATCH 1/2] fix ruff did not format previously --- Makefile | 5 +- examples/hmm_enum.py | 9 +- numpyro/compat/infer.py | 2 +- numpyro/contrib/ecs_proxies.py | 1 - numpyro/contrib/einstein/steinvi.py | 12 +- numpyro/contrib/module.py | 4 +- numpyro/contrib/tfp/distributions.py | 8 +- numpyro/contrib/tfp/mcmc.py | 8 +- numpyro/distributions/conjugate.py | 3 + numpyro/distributions/continuous.py | 20 +- numpyro/distributions/discrete.py | 1 + numpyro/distributions/transforms.py | 8 +- numpyro/distributions/truncated.py | 4 +- numpyro/infer/autoguide.py | 22 ++- numpyro/infer/elbo.py | 22 ++- numpyro/infer/ensemble.py | 201 +++++++++++++------- numpyro/infer/ensemble_util.py | 7 +- numpyro/infer/hmc.py | 8 +- numpyro/infer/hmc_gibbs.py | 2 - numpyro/infer/hmc_util.py | 2 + numpyro/infer/inspect.py | 3 +- numpyro/infer/mixed_hmc.py | 2 +- numpyro/infer/svi.py | 22 ++- numpyro/optim.py | 23 ++- test/contrib/test_funsor.py | 5 +- test/contrib/test_module.py | 24 ++- test/contrib/test_tfp.py | 3 +- test/infer/test_compute_downstream_costs.py | 4 +- test/infer/test_ensemble_mcmc.py | 50 +++-- test/infer/test_ensemble_util.py | 16 +- test/infer/test_mcmc.py | 24 ++- test/infer/test_svi.py | 4 +- test/test_distributions.py | 3 +- test/test_transforms.py | 8 +- 34 files changed, 343 insertions(+), 197 deletions(-) diff --git a/Makefile b/Makefile index d654567f8..500d4f4c8 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,15 @@ all: test lint: FORCE - ruff . + ruff check . + ruff format . --check python scripts/update_headers.py --check license: FORCE python scripts/update_headers.py format: license FORCE - ruff . --fix + ruff format . 
+	ruff format .
install: FORCE pip install -e .[dev,doc,test,examples] diff --git a/examples/hmm_enum.py b/examples/hmm_enum.py index d97af66b1..c6108ba37 100644 --- a/examples/hmm_enum.py +++ b/examples/hmm_enum.py @@ -303,8 +303,13 @@ def transition_fn(carry, y): with numpyro.plate("sequences", num_sequences, dim=-2): with mask(mask=(t < lengths)[..., None]): probs_x_t = Vindex(probs_x)[x_prev, x_curr] - x_prev, x_curr = x_curr, numpyro.sample( - "x", dist.Categorical(probs_x_t), infer={"enumerate": "parallel"} + x_prev, x_curr = ( + x_curr, + numpyro.sample( + "x", + dist.Categorical(probs_x_t), + infer={"enumerate": "parallel"}, + ), ) with numpyro.plate("tones", data_dim, dim=-1): probs_y_t = probs_y[x_curr.squeeze(-1)] diff --git a/numpyro/compat/infer.py b/numpyro/compat/infer.py index de1c6f926..7da6c4075 100644 --- a/numpyro/compat/infer.py +++ b/numpyro/compat/infer.py @@ -126,7 +126,7 @@ def __init__( loss_and_grads=None, num_samples=10, num_steps=0, - **kwargs + **kwargs, ): super(SVI, self).__init__(model=model, guide=guide, optim=optim, loss=loss) self.svi_state = None diff --git a/numpyro/contrib/ecs_proxies.py b/numpyro/contrib/ecs_proxies.py index cc8834a40..c17b2d167 100644 --- a/numpyro/contrib/ecs_proxies.py +++ b/numpyro/contrib/ecs_proxies.py @@ -185,7 +185,6 @@ def log_likelihood_sum(params_flat, subsample_indices=None): ref_sum_log_lik_hessians = hessian(log_likelihood_sum)(ref_params_flat) def gibbs_init(rng_key, gibbs_sites): - ref_subsamples_taylor = [ log_likelihood(ref_params_flat, gibbs_sites), jacobian(log_likelihood)(ref_params_flat, gibbs_sites), diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 0436b10ed..0dbabb381 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -217,13 +217,11 @@ def local_trace(key): def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs): # 0. Separate model and guide parameters, since only guide parameters are updated using Stein - non_mixture_uparams = ( - { # Includes any marked guide parameters and all model parameters - p: v - for p, v in unconstr_params.items() - if p not in self.guide_sites or self.non_mixture_params_fn(p) - } - ) + non_mixture_uparams = { # Includes any marked guide parameters and all model parameters + p: v + for p, v in unconstr_params.items() + if p not in self.guide_sites or self.non_mixture_params_fn(p) + } stein_uparams = { p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams } diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py index ae917c18e..f370b4330 100644 --- a/numpyro/contrib/module.py +++ b/numpyro/contrib/module.py @@ -254,7 +254,7 @@ def random_flax_module( input_shape=None, apply_rng=None, mutable=None, - **kwargs + **kwargs, ): """ A primitive to place a prior over the parameters of the Flax module `nn_module`. @@ -372,7 +372,7 @@ def __call__(self, x): input_shape=input_shape, apply_rng=apply_rng, mutable=mutable, - **kwargs + **kwargs, ) params = nn.args[0] new_params = deepcopy(params) diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index afce44f7f..4db875cfe 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -314,9 +314,7 @@ def kl_divergence(p, q): # noqa: F811 _PyroDist.__doc__ = """ Wraps `{}.{} `_ with :class:`~numpyro.contrib.tfp.distributions.TFPDistribution`. 
- """.format( - _Dist.__module__, _Dist.__name__, _Dist.__name__ - ) + """.format(_Dist.__module__, _Dist.__name__, _Dist.__name__) __all__.append(_name) @@ -328,9 +326,7 @@ def kl_divergence(p, q): # noqa: F811 {0} ---------------------------------------------------------------- .. autoclass:: numpyro.contrib.tfp.distributions.{0} - """.format( - _name - ) + """.format(_name) for _name in __all__[:_len_all] ] ) diff --git a/numpyro/contrib/tfp/mcmc.py b/numpyro/contrib/tfp/mcmc.py index bce3b8da7..b660af837 100644 --- a/numpyro/contrib/tfp/mcmc.py +++ b/numpyro/contrib/tfp/mcmc.py @@ -236,9 +236,7 @@ def sample(self, state, model_args, model_kwargs): Wraps `{}.{} `_ with :class:`~numpyro.contrib.tfp.mcmc.TFPKernel`. The first argument `target_log_prob_fn` in TFP kernel construction is replaced by either `model` or `potential_fn`. - """.format( - _Kernel.__module__, _Kernel.__name__, _Kernel.__name__ - ) + """.format(_Kernel.__module__, _Kernel.__name__, _Kernel.__name__) __all__.append(_name) @@ -250,9 +248,7 @@ def sample(self, state, model_args, model_kwargs): {0} ---------------------------------------------------------------- .. autoclass:: numpyro.contrib.tfp.mcmc.{0} - """.format( - _name - ) + """.format(_name) for _name in __all__[:1] + sorted(__all__[1:]) ] ) diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index f0c7b93c7..d3d364da9 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -36,6 +36,7 @@ class BetaBinomial(Distribution): Beta distribution. :param numpy.ndarray total_count: number of Bernoulli trials. """ + arg_constraints = { "concentration1": constraints.positive, "concentration0": constraints.positive, @@ -107,6 +108,7 @@ class DirichletMultinomial(Distribution): Dirichlet distribution. :param numpy.ndarray total_count: number of Categorical trials. """ + arg_constraints = { "concentration": constraints.independent(constraints.positive, 1), "total_count": constraints.nonnegative_integer, @@ -182,6 +184,7 @@ class GammaPoisson(Distribution): :param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution. :param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution. 
""" + arg_constraints = { "concentration": constraints.positive, "rate": constraints.positive, diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 7d1ee97a0..861eddf1b 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -284,9 +284,7 @@ def mean(self): @property def variance(self): con0 = jnp.sum(self.concentration, axis=-1, keepdims=True) - return ( - self.concentration * (con0 - self.concentration) / (con0**2 * (con0 + 1)) - ) + return self.concentration * (con0 - self.concentration) / (con0**2 * (con0 + 1)) @staticmethod def infer_shapes(concentration): @@ -909,6 +907,7 @@ def model(y): # y has dimension N x d [1] `Generating random correlation matrices based on vines and extended onion method`, Daniel Lewandowski, Dorota Kurowicka, Harry Joe """ + arg_constraints = {"concentration": constraints.positive} reparametrized_params = ["concentration"] support = constraints.corr_matrix @@ -985,6 +984,7 @@ def model(y): # y has dimension N x d [1] `Generating random correlation matrices based on vines and extended onion method`, Daniel Lewandowski, Dorota Kurowicka, Harry Joe """ + arg_constraints = {"concentration": constraints.positive} reparametrized_params = ["concentration"] support = constraints.corr_cholesky @@ -1961,9 +1961,10 @@ def scale_tril(self): @lazy_property def covariance_matrix(self): - covariance_matrix = add_diag(jnp.matmul( - self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2) - ), self.cov_diag) + covariance_matrix = add_diag( + jnp.matmul(self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)), + self.cov_diag, + ) return covariance_matrix @lazy_property @@ -1976,7 +1977,7 @@ def precision_matrix(self): ) A = solve_triangular(Wt_Dinv, self._capacitance_tril, lower=True) inverse_cov_diag = jnp.reciprocal(self.cov_diag) - return add_diag(- jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag) + return add_diag(-jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag) def sample(self, key, sample_shape=()): assert is_prng_key(key) @@ -2068,8 +2069,9 @@ class Pareto(TransformedDistribution): def __init__(self, scale, alpha, *, validate_args=None): self.scale, self.alpha = promote_shapes(scale, alpha) batch_shape = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(alpha)) - scale, alpha = jnp.broadcast_to(scale, batch_shape), jnp.broadcast_to( - alpha, batch_shape + scale, alpha = ( + jnp.broadcast_to(scale, batch_shape), + jnp.broadcast_to(alpha, batch_shape), ) base_dist = Exponential(alpha) transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)] diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index e5fbb2f88..a5fd12536 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -679,6 +679,7 @@ class Poisson(Distribution): :param bool is_sparse: Whether to assume value is mostly zero when computing :meth:`log_prob`, which can speed up computation when data is sparse. """ + arg_constraints = {"rate": constraints.positive} support = constraints.nonnegative_integer pytree_aux_fields = ("is_sparse",) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index a908a7eff..55c6f40b7 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -400,6 +400,7 @@ class CholeskyTransform(ParameterFreeTransform): Transform via the mapping :math:`y = cholesky(x)`, where `x` is a positive definite matrix. 
""" + domain = constraints.positive_definite codomain = constraints.lower_cholesky @@ -444,6 +445,7 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a c. Applies :math:`s_i = StickBreakingTransform(z_i)`. d. Transforms back into signed domain: :math:`y_i = (sign(r_i), 1) * \sqrt{s_i}`. """ + domain = constraints.real_vector codomain = constraints.corr_cholesky @@ -493,6 +495,7 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): Transform via the mapping :math:`y = cholesky(x)`, where `x` is a correlation matrix. """ + domain = constraints.corr_matrix codomain = constraints.corr_cholesky @@ -624,6 +627,7 @@ class L1BallTransform(ParameterFreeTransform): r""" Transforms a uncontrained real vector :math:`x` into the unit L1 ball. """ + domain = constraints.real_vector codomain = constraints.l1_ball @@ -687,6 +691,7 @@ class LowerCholeskyAffine(Transform): >>> affine(base) Array([0.3, 1.5], dtype=float32) """ + domain = constraints.real_vector codomain = constraints.real_vector @@ -786,6 +791,7 @@ class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform): and :math:`scale\_diag` is a diagonal matrix with all positive entries that is parameterized with a softplus transform. """ + domain = constraints.real_vector codomain = constraints.scaled_unit_lower_cholesky @@ -995,6 +1001,7 @@ class SoftplusTransform(ParameterFreeTransform): Transform from unconstrained space to positive domain via softplus :math:`y = \log(1 + \exp(x))`. The inverse is computed as :math:`x = \log(\exp(y) - 1)`. """ + domain = constraints.real codomain = constraints.softplus_positive @@ -1200,7 +1207,6 @@ def __eq__(self, other): ) - ########################################################## # CONSTRAINT_REGISTRY ########################################################## diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 53edf234a..078ea83bc 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -263,9 +263,7 @@ def mean(self): if isinstance(self.base_dist, Normal): low_prob = jnp.exp(self.log_prob(self.low)) high_prob = jnp.exp(self.log_prob(self.high)) - return ( - self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2 - ) + return self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2 elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 677b13a81..351cb6390 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -1858,7 +1858,7 @@ def _setup_prototype(self, *args, **kwargs): f"Expected {self.batch_ndim} batch dimensions, but site " f"`{site['name']}` only has shape {shape}." ) - shape = shape[:self.batch_ndim] + shape = shape[: self.batch_ndim] if batch_shape is None: batch_shape = shape elif shape != batch_shape: @@ -1884,7 +1884,9 @@ def _get_posterior(self): ) def _get_reshape_transform(self) -> ReshapeTransform: - return ReshapeTransform((self.latent_dim,), self._batch_shape + self._event_shape) + return ReshapeTransform( + (self.latent_dim,), self._batch_shape + self._event_shape + ) class AutoBatchedMultivariateNormal(AutoBatchedMixin, AutoContinuous): @@ -1914,7 +1916,10 @@ def __init__( raise ValueError("Expected init_scale > 0. 
but got {}".format(init_scale)) self._init_scale = init_scale super().__init__( - model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim, + model, + prefix=prefix, + init_loc_fn=init_loc_fn, + batch_ndim=batch_ndim, ) def _get_batched_posterior(self): @@ -2044,16 +2049,21 @@ def __init__( self._init_scale = init_scale self.rank = rank super().__init__( - model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim, + model, + prefix=prefix, + init_loc_fn=init_loc_fn, + batch_ndim=batch_ndim, ) def _get_batched_posterior(self): - rank = int(round(self._event_shape[0]**0.5)) if self.rank is None else self.rank + rank = ( + int(round(self._event_shape[0] ** 0.5)) if self.rank is None else self.rank + ) init_latent = self._init_latent.reshape(self._batch_shape + self._event_shape) loc = numpyro.param("{}_loc".format(self.prefix), init_latent) cov_factor = numpyro.param( "{}_cov_factor".format(self.prefix), - jnp.zeros(self._batch_shape + self._event_shape + (rank,)) + jnp.zeros(self._batch_shape + self._event_shape + (rank,)), ) scale = numpyro.param( "{}_scale".format(self.prefix), diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index eb4e11b68..5a7e240f7 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -261,16 +261,18 @@ def _check_mean_field_requirement(model_trace, guide_trace): ] assert set(model_sites) == set(guide_sites) if model_sites != guide_sites: - warnings.warn( - "Failed to verify mean field restriction on the guide. " - "To eliminate this warning, ensure model and guide sites " - "occur in the same order.\n" - + "Model sites:\n " - + "\n ".join(model_sites) - + "Guide sites:\n " - + "\n ".join(guide_sites), - stacklevel=find_stack_level(), - ), + ( + warnings.warn( + "Failed to verify mean field restriction on the guide. " + "To eliminate this warning, ensure model and guide sites " + "occur in the same order.\n" + + "Model sites:\n " + + "\n ".join(model_sites) + + "Guide sites:\n " + + "\n ".join(guide_sites), + stacklevel=find_stack_level(), + ), + ) class TraceMeanField_ELBO(ELBO): diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 5d7a9c119..7b0e1e5ef 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ -41,13 +41,9 @@ - **rng_key** - random number generator seed used for generating proposals, etc. """ -ESSState = namedtuple("ESSState", ["i", - "n_expansions", - "n_contractions", - "mu", - "rng_key" - ] - ) +ESSState = namedtuple( + "ESSState", ["i", "n_expansions", "n_contractions", "mu", "rng_key"] +) """ A :func:`~collections.namedtuple` used as an inner state for Ensemble Sampler. This consists of the following fields: @@ -76,7 +72,9 @@ class EnsembleSampler(MCMCKernel, ABC): See :ref:`init_strategy` section for available functions. 
""" - def __init__(self, model=None, potential_fn=None, *, randomize_split, init_strategy): + def __init__( + self, model=None, potential_fn=None, *, randomize_split, init_strategy + ): if not (model is None) ^ (potential_fn is None): raise ValueError("Only one of `model` or `potential_fn` must be specified.") @@ -118,7 +116,12 @@ def update_active_chains(self, active, inactive, inner_state): def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self._model is not None: - new_params_info, potential_fn_gen, self._postprocess_fn, _ = initialize_model( + ( + new_params_info, + potential_fn_gen, + self._postprocess_fn, + _, + ) = initialize_model( rng_key, self._model, dynamic_args=True, @@ -137,18 +140,20 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self._batch_log_density = lambda z: -vmap(self._potential_fn)(unravel_fn(z)) if self._num_chains < 2 * flat_params.shape[1]: - warnings.warn("Setting n_chains to at least 2*n_params is strongly recommended.\n" - f"n_chains: {self._num_chains}, n_params: {flat_params.shape[1]}") + warnings.warn( + "Setting n_chains to at least 2*n_params is strongly recommended.\n" + f"n_chains: {self._num_chains}, n_params: {flat_params.shape[1]}" + ) return init_params def init( self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={} ): - assert not is_prng_key( - rng_key - ), ("EnsembleSampler only supports chain_method='vectorized' with num_chains > 1.\n" - "If you want to run chains in parallel, please raise a github issue.") + assert not is_prng_key(rng_key), ( + "EnsembleSampler only supports chain_method='vectorized' with num_chains > 1.\n" + "If you want to run chains in parallel, please raise a github issue." + ) assert rng_key.shape[0] % 2 == 0, "Number of chains must be even." @@ -159,9 +164,12 @@ def init( "Valid value of `init_params` must be provided with `potential_fn`." 
) if init_params is not None: - assert all([param.shape[0] == self._num_chains - for param in jax.tree_util.tree_leaves(init_params)]), ( - "The batch dimension of each param must match n_chains") + assert all( + [ + param.shape[0] == self._num_chains + for param in jax.tree_util.tree_leaves(init_params) + ] + ), "The batch dimension of each param must match n_chains" rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3) rng_key_init_model = random.split(rng_key_init_model, self._num_chains) @@ -194,22 +202,27 @@ def sample(self, state, model_args, model_kwargs): def body_fn(i, z_flat_inner_state): z_flat, inner_state = z_flat_inner_state - active, inactive = jax.lax.cond(i == 0, - lambda x: (x[:split_ind], x[split_ind:]), - lambda x: (x[split_ind:], x[split_ind:]), - z_flat) + active, inactive = jax.lax.cond( + i == 0, + lambda x: (x[:split_ind], x[split_ind:]), + lambda x: (x[split_ind:], x[split_ind:]), + z_flat, + ) - z_updates, inner_state = self.update_active_chains(active, inactive, inner_state) + z_updates, inner_state = self.update_active_chains( + active, inactive, inner_state + ) - z_flat = jax.lax.cond(i == 0, - lambda x: x.at[:split_ind].set(z_updates), - lambda x: x.at[split_ind:].set(z_updates), - z_flat) + z_flat = jax.lax.cond( + i == 0, + lambda x: x.at[:split_ind].set(z_updates), + lambda x: x.at[split_ind:].set(z_updates), + z_flat, + ) return (z_flat, inner_state) z_flat, inner_state = jax.lax.fori_loop(0, 2, body_fn, (z_flat, inner_state)) - return EnsembleSamplerState(unravel_fn(z_flat), inner_state, rng_key) @@ -261,30 +274,46 @@ class AIES(EnsembleSampler): >>> mcmc.run(jax.random.PRNGKey(0)) """ - def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=None, init_strategy=init_to_uniform): + def __init__( + self, + model=None, + potential_fn=None, + randomize_split=False, + moves=None, + init_strategy=init_to_uniform, + ): if not moves: self._moves = [AIES.DEMove()] self._weights = jnp.array([1.0]) else: self._moves = list(moves.keys()) - self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) - - assert all([hasattr(move, '__call__') for move in self._moves]), ( - "Each move must be a callable (one of AIES.DEMove(), or AIES.StretchMove()).") - assert jnp.all(self._weights >= 0), "Each specified move must have probability >= 0" + self._weights = jnp.array([weight for weight in moves.values()]) / len( + moves + ) - super().__init__(model, - potential_fn, - randomize_split=randomize_split, - init_strategy=init_strategy) + assert all( + [hasattr(move, "__call__") for move in self._moves] + ), "Each move must be a callable (one of AIES.DEMove(), or AIES.StretchMove())." + assert jnp.all( + self._weights >= 0 + ), "Each specified move must have probability >= 0" + + super().__init__( + model, + potential_fn, + randomize_split=randomize_split, + init_strategy=init_strategy, + ) def get_diagnostics_str(self, state): return "acc. prob={:.2f}".format(state.inner_state.mean_accept_prob) def init_inner_state(self, rng_key): # XXX hack -- we don't know num_chains until we init the inner state - self._moves = [move(self._num_chains) if move.__name__ == 'make_de_move' - else move for move in self._moves] + self._moves = [ + move(self._num_chains) if move.__name__ == "make_de_move" else move + for move in self._moves + ] return AIESState(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), rng_key) @@ -335,6 +364,7 @@ def DEMove(sigma=1.0e-5, g0=None): The mean stretch factor for the proposal vector. 
By default, it is `2.38 / sqrt(2*ndim)` as recommended by the two references. """ + def make_de_move(n_chains): PAIRS = get_nondiagonal_indices(n_chains // 2) @@ -346,7 +376,9 @@ def de_move(rng_key, active, inactive): # recompute this each time g = 2.38 / jnp.sqrt(2.0 * n_params) if not g0 else g0 - selected_pairs = random.choice(pairs_key, PAIRS, shape=(n_active_chains,)) + selected_pairs = random.choice( + pairs_key, PAIRS, shape=(n_active_chains,) + ) # Compute diff vectors diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze(axis=1) @@ -378,6 +410,7 @@ def StretchMove(a=2.0): :param a: (optional) The stretch scale parameter. (default: ``2.0``) """ + def stretch_move(rng_key, active, inactive): n_active_chains, n_params = active.shape unif_key, idx_key = random.split(rng_key) @@ -390,7 +423,9 @@ def stretch_move(rng_key, active, inactive): idx_key, shape=(n_active_chains,), minval=0, maxval=n_active_chains ) - proposal = inactive[r_idxs] - (inactive[r_idxs] - active) * zz[:, jnp.newaxis] + proposal = ( + inactive[r_idxs] - (inactive[r_idxs] - active) * zz[:, jnp.newaxis] + ) return proposal, factors @@ -455,6 +490,7 @@ class ESS(EnsembleSampler): >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') >>> mcmc.run(jax.random.PRNGKey(0)) """ + def __init__( self, model=None, @@ -472,13 +508,18 @@ def __init__( self._weights = jnp.array([1.0]) else: self._moves = list(moves.keys()) - self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) + self._weights = jnp.array([weight for weight in moves.values()]) / len( + moves + ) - assert all([hasattr(move, '__call__') for move in self._moves]), ( + assert all([hasattr(move, "__call__") for move in self._moves]), ( "Each move must be a callable (one of `ESS.DifferentialMove()`, " - "`ESS.GaussianMove()`, `ESS.KDEMove()`, `ESS.RandomMove()`)") + "`ESS.GaussianMove()`, `ESS.KDEMove()`, `ESS.RandomMove()`)" + ) - assert jnp.all(self._weights >= 0), "Each specified move must have probability >= 0" + assert jnp.all( + self._weights >= 0 + ), "Each specified move must have probability >= 0" assert init_mu > 0, "Scale factor should be strictly positive" self._max_steps = max_steps # max number of stepping out steps @@ -486,28 +527,38 @@ def __init__( self._init_mu = init_mu self._tune_mu = tune_mu - super().__init__(model, - potential_fn, - randomize_split=randomize_split, - init_strategy=init_strategy) + super().__init__( + model, + potential_fn, + randomize_split=randomize_split, + init_strategy=init_strategy, + ) def init_inner_state(self, rng_key): self.batch_log_density = lambda x: self._batch_log_density(x)[:, jnp.newaxis] # XXX hack -- we don't know num_chains until we init the inner state - self._moves = [move(self._num_chains) if move.__name__ == 'make_differential_move' - else move for move in self._moves] - - return ESSState(jnp.array(0.0), jnp.array(0), jnp.array(0), self._init_mu, rng_key) + self._moves = [ + move(self._num_chains) + if move.__name__ == "make_differential_move" + else move + for move in self._moves + ] + + return ESSState( + jnp.array(0.0), jnp.array(0), jnp.array(0), self._init_mu, rng_key + ) def update_active_chains(self, active, inactive, inner_state): i, n_expansions, n_contractions, mu, rng_key = inner_state - (rng_key, - move_key, - dir_key, - height_key, - step_out_key, - shrink_key) = random.split(rng_key, 6) + ( + rng_key, + move_key, + dir_key, + height_key, + step_out_key, + shrink_key, + ) = random.split(rng_key, 6) n_active_chains, 
n_params = active.shape @@ -533,20 +584,20 @@ def update_active_chains(self, active, inactive, inner_state): safe_n_expansions = jnp.max(jnp.array([1, n_expansions])) # only update tuning scale if a full iteration has passed - mu, n_expansions, n_contractions = jax.lax.cond(jnp.all(itr % 1 == 0), - lambda n_exp, n_con: (2.0 * n_exp / (n_exp + n_con), - jnp.array(0), - jnp.array(0) - ), - lambda _, __: (mu, - n_expansions, - n_contractions - ), - safe_n_expansions, n_contractions) + mu, n_expansions, n_contractions = jax.lax.cond( + jnp.all(itr % 1 == 0), + lambda n_exp, n_con: ( + 2.0 * n_exp / (n_exp + n_con), + jnp.array(0), + jnp.array(0), + ), + lambda _, __: (mu, n_expansions, n_contractions), + safe_n_expansions, + n_contractions, + ) return proposal, ESSState(itr, n_expansions, n_contractions, mu, rng_key) - @staticmethod def RandomMove(): """ @@ -555,6 +606,7 @@ def RandomMove(): walkers and this Move corresponds to the vanilla Slice Sampling method. This Move should be used for debugging purposes only. """ + def random_move(rng_key, inactive, mu): directions = dist.Normal(loc=0, scale=1).sample( rng_key, sample_shape=inactive.shape @@ -562,6 +614,7 @@ def random_move(rng_key, inactive, mu): directions /= jnp.linalg.norm(directions, axis=0) return 2.0 * mu * directions + return random_move @staticmethod @@ -572,6 +625,7 @@ def KDEMove(bw_method=None): a Gaussian Kernel Density Estimation methods. The walkers then move along random direction vectos sampled from this distribution. """ + def kde_move(rng_key, inactive, mu): n_active_chains, n_params = inactive.shape @@ -581,6 +635,7 @@ def kde_move(rng_key, inactive, mu): directions = vectors[:n_active_chains] - vectors[n_active_chains:] return 2.0 * mu * directions + return kde_move @staticmethod @@ -609,6 +664,7 @@ def gaussian_move(rng_key, inactive, mu): rng_key, sample_shape=(n_active_chains,) ) ) + return gaussian_move @staticmethod @@ -619,6 +675,7 @@ def DifferentialMove(): replacement) from the complementary ensemble. This is the default choice and performs well along a wide range of target distributions. 
""" + def make_differential_move(n_chains): PAIRS = get_nondiagonal_indices(n_chains // 2) @@ -631,11 +688,11 @@ def differential_move(rng_key, inactive, mu): ) # get the pairwise difference of each vector return 2.0 * mu * diffs + return differential_move return make_differential_move - def _step_out(self, rng_key, log_slice_height, active, directions): init_L_key, init_J_key = random.split(rng_key) n_active_chains, n_params = active.shape diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py index 028d694a5..9f213ea5c 100644 --- a/numpyro/infer/ensemble_util.py +++ b/numpyro/infer/ensemble_util.py @@ -18,8 +18,9 @@ def get_nondiagonal_indices(n): rows, cols = np.tril_indices(n, -1) # -1 to exclude diagonal # Combine rows-cols and cols-rows pairs - pairs = np.column_stack([np.concatenate([rows, cols]), - np.concatenate([cols, rows])]) + pairs = np.column_stack( + [np.concatenate([rows, cols]), np.concatenate([cols, rows])] + ) return jnp.asarray(pairs) @@ -43,5 +44,3 @@ def batch_ravel_pytree(pytree): unravel_fn = jax.vmap(ravel_pytree(tree_map(lambda z: z[0], pytree))[1]) return flat, unravel_fn - - diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index aa2fcd802..709c16824 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -287,7 +287,13 @@ def init_kernel( trajectory_length = lax.convert_element_type( trajectory_length, jnp.result_type(float) ) - nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad, fixed_num_steps + nonlocal \ + wa_update, \ + max_treedepth, \ + vv_update, \ + wa_steps, \ + forward_mode_ad, \ + fixed_num_steps forward_mode_ad = forward_mode_differentiation wa_steps = num_warmup max_treedepth = ( diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 5b7fb7438..53622dcf2 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -687,8 +687,6 @@ def taylor_proxy(reference_params, degree=2): return taylor_proxy(reference_params, degree) - - class estimate_likelihood(numpyro.primitives.Messenger): def __init__(self, fn=None, method=None): # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim) diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index 07f36a81a..51e628148 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -241,9 +241,11 @@ def final_fn(state, regularize=False): def _value_and_grad(f, x, forward_mode_differentiation=False): if forward_mode_differentiation: + def _wrapper(x): out = f(x) return out, out + grads, out = jacfwd(_wrapper, has_aux=True)(x) return out, grads else: diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index 5232dfe11..abfb9bd71 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -385,8 +385,7 @@ def process_message(self, msg): samples = { name: site["value"] for name, site in trace.items() - if site["type"] == "sample" - or site["type"] == "deterministic" + if site["type"] == "sample" or site["type"] == "deterministic" } params = { diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py index 726d18c4c..deea69a79 100644 --- a/numpyro/infer/mixed_hmc.py +++ b/numpyro/infer/mixed_hmc.py @@ -74,7 +74,7 @@ def __init__( *, num_discrete_updates=None, random_walk=False, - modified=False + modified=False, ): super().__init__(inner_kernel, random_walk=random_walk, modified=modified) if inner_kernel._algo == "NUTS": diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index 77f05255f..b21b531de 100644 --- 
a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -283,11 +283,15 @@ def update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs) mutable_state=svi_state.mutable_state, ) (loss_val, mutable_state), optim_state = self.optim.eval_and_update( - loss_fn, svi_state.optim_state, forward_mode_differentiation=forward_mode_differentiation + loss_fn, + svi_state.optim_state, + forward_mode_differentiation=forward_mode_differentiation, ) return SVIState(optim_state, mutable_state, rng_key), loss_val - def stable_update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs): + def stable_update( + self, svi_state, *args, forward_mode_differentiation=False, **kwargs + ): """ Similar to :meth:`update` but returns the current state if the the loss or the new state contains invalid values. @@ -314,7 +318,9 @@ def stable_update(self, svi_state, *args, forward_mode_differentiation=False, ** mutable_state=svi_state.mutable_state, ) (loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update( - loss_fn, svi_state.optim_state, forward_mode_differentiation=forward_mode_differentiation + loss_fn, + svi_state.optim_state, + forward_mode_differentiation=forward_mode_differentiation, ) return SVIState(optim_state, mutable_state, rng_key), loss_val @@ -378,11 +384,17 @@ def run( def body_fn(svi_state, _): if stable_update: svi_state, loss = self.stable_update( - svi_state, *args, forward_mode_differentiation=forward_mode_differentiation, **kwargs + svi_state, + *args, + forward_mode_differentiation=forward_mode_differentiation, + **kwargs, ) else: svi_state, loss = self.update( - svi_state, *args, forward_mode_differentiation=forward_mode_differentiation, **kwargs + svi_state, + *args, + forward_mode_differentiation=forward_mode_differentiation, + **kwargs, ) return svi_state, loss diff --git a/numpyro/optim.py b/numpyro/optim.py index 5ad451bab..225a6f6bb 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -34,16 +34,20 @@ _OptState = TypeVar("_OptState") _IterOptState = tuple[int, _OptState] + def _value_and_grad(f, x, forward_mode_differentiation=False): if forward_mode_differentiation: + def _wrapper(x): out, aux = f(x) return out, (out, aux) + grads, (out, aux) = jacfwd(_wrapper, has_aux=True)(x) return (out, aux), grads else: return value_and_grad(f, has_aux=True)(x) + class _NumPyroOptim(object): def __init__(self, optim_fn: Callable, *args, **kwargs) -> None: self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs) @@ -71,7 +75,10 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState: return i + 1, opt_state def eval_and_update( - self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation: bool = False + self, + fn: Callable[[Any], tuple], + state: _IterOptState, + forward_mode_differentiation: bool = False, ): """ Performs an optimization step for the objective function `fn`. 
@@ -95,8 +102,11 @@ def eval_and_update( return (out, aux), self.update(grads, state) def eval_and_stable_update( - self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation: bool = False - ): + self, + fn: Callable[[Any], tuple], + state: _IterOptState, + forward_mode_differentiation: bool = False, + ): """ Like :meth:`eval_and_update` but when the value of the objective function or the gradients are not finite, we will not update the input `state` @@ -286,8 +296,11 @@ def __init__(self, method="BFGS", **kwargs): self._kwargs = kwargs def eval_and_update( - self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation=False - ): + self, + fn: Callable[[Any], tuple], + state: _IterOptState, + forward_mode_differentiation=False, + ): i, (flat_params, unravel_fn) = state def loss_fn(x): diff --git a/test/contrib/test_funsor.py b/test/contrib/test_funsor.py index 30427b960..ad6037f17 100644 --- a/test/contrib/test_funsor.py +++ b/test/contrib/test_funsor.py @@ -485,8 +485,9 @@ def model(): x_curr = 0 for t in markov(range(T), history=history): probs = p[x_prev, x_curr, z] - x_prev, x_curr = x_curr, numpyro.sample( - "x_{}".format(t), dist.Bernoulli(probs) + x_prev, x_curr = ( + x_curr, + numpyro.sample("x_{}".format(t), dist.Bernoulli(probs)), ) numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0) return x_prev, x_curr diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 803af666c..5f43dcb3e 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -23,7 +23,9 @@ import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS -pytestmark = pytest.mark.filterwarnings("ignore:jax.tree_.+ is deprecated:FutureWarning") +pytestmark = pytest.mark.filterwarnings( + "ignore:jax.tree_.+ is deprecated:FutureWarning" +) def haiku_model_by_shape(x, y): @@ -117,12 +119,16 @@ def test_haiku_module(): 100, 100, ) - assert haiku_tr["nn$params"]["value"]["test_haiku_module/w_linear"]["b"].shape == (100,) + assert haiku_tr["nn$params"]["value"]["test_haiku_module/w_linear"]["b"].shape == ( + 100, + ) assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["w"].shape == ( 100, 100, ) - assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["b"].shape == (100,) + assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["b"].shape == ( + 100, + ) def test_update_params(): @@ -131,7 +137,9 @@ def test_update_params(): new_params = deepcopy(params) with handlers.seed(rng_seed=0): _update_params(params, new_params, prior) - assert params == {"a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))}} + assert params == { + "a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))} + } tree_all( tree_map( @@ -198,7 +206,9 @@ def model(data, labels): numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels) kernel = NUTS(model=model) - mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) mcmc.run(random.PRNGKey(2), data, labels) mcmc.print_summary() samples = mcmc.get_samples() @@ -222,7 +232,9 @@ def fn(x): if dropout: x = hk.dropout(hk.next_rng_key(), 0.5, x) if batchnorm: - x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(x, is_training=True) + x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)( + x, is_training=True + ) return x def model(): diff --git 
a/test/contrib/test_tfp.py b/test/contrib/test_tfp.py index 15c72d8d5..ab3adf64c 100644 --- a/test/contrib/test_tfp.py +++ b/test/contrib/test_tfp.py @@ -280,7 +280,8 @@ def test_sample_unwrapped_mixture_same_family(): tfd.MixtureSameFamily( mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]), components_distribution=tfd.Normal( - loc=[-1.0, 1], scale=[0.1, 0.5] # One for each component. + loc=[-1.0, 1], + scale=[0.1, 0.5], # One for each component. ), ), ) diff --git a/test/infer/test_compute_downstream_costs.py b/test/infer/test_compute_downstream_costs.py index 700e11d84..170785928 100644 --- a/test/infer/test_compute_downstream_costs.py +++ b/test/infer/test_compute_downstream_costs.py @@ -24,7 +24,9 @@ def _brute_force_compute_downstream_costs( - model_trace, guide_trace, non_reparam_nodes # + model_trace, + guide_trace, + non_reparam_nodes, # ): model_successors = _identify_dense_edges(model_trace) guide_successors = _identify_dense_edges(guide_trace) diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py index 7cc289d89..02c065aa4 100644 --- a/test/infer/test_ensemble_mcmc.py +++ b/test/infer/test_ensemble_mcmc.py @@ -20,42 +20,62 @@ logits = jnp.sum(true_coefs * data, axis=-1) labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1)) + def model(labels): coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1)) return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) + + # --- -@pytest.mark.parametrize("kernel_cls, n_chain, method", - [(AIES, 10, "sequential"), - (AIES, 1, "vectorized"), - (AIES, 2, "parallel"), - (ESS, 10, "sequential"), - (ESS, 1, "vectorized"), - (ESS, 2, "parallel")]) + +@pytest.mark.parametrize( + "kernel_cls, n_chain, method", + [ + (AIES, 10, "sequential"), + (AIES, 1, "vectorized"), + (AIES, 2, "parallel"), + (ESS, 10, "sequential"), + (ESS, 1, "vectorized"), + (ESS, 2, "parallel"), + ], +) def test_chain_smoke(kernel_cls, n_chain, method): kernel = kernel_cls(model) - mcmc = MCMC(kernel, num_warmup=10, num_samples=10, - progress_bar=False, num_chains=n_chain, chain_method=method) + mcmc = MCMC( + kernel, + num_warmup=10, + num_samples=10, + progress_bar=False, + num_chains=n_chain, + chain_method=method, + ) with pytest.raises(AssertionError, match="chain_method"): mcmc.run(random.PRNGKey(2), labels) + @pytest.mark.parametrize("kernel_cls", [AIES, ESS]) def test_out_shape_smoke(kernel_cls): n_chains = 10 kernel = kernel_cls(model) - mcmc = MCMC(kernel, num_warmup=10, num_samples=10, - progress_bar=False, num_chains=n_chains, chain_method='vectorized') + mcmc = MCMC( + kernel, + num_warmup=10, + num_samples=10, + progress_bar=False, + num_chains=n_chains, + chain_method="vectorized", + ) mcmc.run(random.PRNGKey(2), labels) - assert (mcmc.get_samples(group_by_chain=True)['coefs'].shape[0] - == n_chains) + assert mcmc.get_samples(group_by_chain=True)["coefs"].shape[0] == n_chains + @pytest.mark.parametrize("kernel_cls", [AIES, ESS]) def test_invalid_moves(kernel_cls): with pytest.raises(AssertionError, match="Each move"): - kernel_cls(model, moves={'invalid': 1.}) - + kernel_cls(model, moves={"invalid": 1.0}) diff --git a/test/infer/test_ensemble_util.py b/test/infer/test_ensemble_util.py index 5a066c69d..ad28a76ad 100644 --- a/test/infer/test_ensemble_util.py +++ b/test/infer/test_ensemble_util.py @@ -8,28 +8,24 @@ def test_nondiagonal_indices(): - truth = jnp.array( - [[1, 0], - [2, 0], - [2, 1], - [0, 1], - [0, 2], - 
[1, 2]], dtype=jnp.int32) + truth = jnp.array([[1, 0], [2, 0], [2, 1], [0, 1], [0, 2], [1, 2]], dtype=jnp.int32) assert jnp.all(get_nondiagonal_indices(3) == truth) + def test_batch_ravel_pytree(): arr1 = jnp.arange(10).reshape((5, 2)) arr2 = jnp.arange(15).reshape((5, 3)) arr3 = jnp.arange(20).reshape((5, 4)) - tree = {'arr1': arr1, 'arr2': arr2, 'arr3': arr3} + tree = {"arr1": arr1, "arr2": arr2, "arr3": arr3} flattened, unravel_fn = batch_ravel_pytree(tree) unflattened = unravel_fn(flattened) assert flattened.shape == (5, 2 + 3 + 4) - for unflattened_leaf, original_leaf in zip(jax.tree_util.tree_leaves(unflattened), - jax.tree_util.tree_leaves(tree)): + for unflattened_leaf, original_leaf in zip( + jax.tree_util.tree_leaves(unflattened), jax.tree_util.tree_leaves(tree) + ): assert jnp.all(unflattened_leaf == original_leaf) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 3f52c6c1e..fa1fea460 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -40,8 +40,12 @@ def potential_fn(z): init_params = random.normal(random.PRNGKey(1), (num_chains,)) mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False, - num_chains=num_chains, chain_method='vectorized' + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + progress_bar=False, + num_chains=num_chains, + chain_method="vectorized", ) elif kernel_cls in [SA, BarkerMH]: kernel = kernel_cls(potential_fn=potential_fn, dense_mass=dense_mass) @@ -124,8 +128,12 @@ def model(labels): kernel = kernel_cls(model) mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=samples_each_chain, - progress_bar=False, num_chains=num_chains, chain_method='vectorized' + kernel, + num_warmup=num_warmup, + num_samples=samples_each_chain, + progress_bar=False, + num_chains=num_chains, + chain_method="vectorized", ) elif kernel_cls is SA: num_warmup, num_samples = (100000, 100000) @@ -242,8 +250,12 @@ def model(data): num_chains = 10 kernel = kernel_cls(model=model) mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, - progress_bar=False, num_chains=num_chains, chain_method='vectorized' + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + progress_bar=False, + num_chains=num_chains, + chain_method="vectorized", ) elif kernel_cls is SA: kernel = SA(model=model) diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py index 65a1751ff..166499063 100644 --- a/test/infer/test_svi.py +++ b/test/infer/test_svi.py @@ -766,8 +766,8 @@ def model(): numpyro.sample("obs", dist.Normal(y, 1), obs=1.0) def guide(): - loc = numpyro.param("loc", 0.) - scale = numpyro.param("scale", 1., constraint=dist.constraints.positive) + loc = numpyro.param("loc", 0.0) + scale = numpyro.param("scale", 1.0, constraint=dist.constraints.positive) numpyro.sample("x", dist.Normal(loc, scale)) # this fails in reverse mode diff --git a/test/test_distributions.py b/test/test_distributions.py index 462a461a0..0119aa1b8 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3061,7 +3061,8 @@ def sample(d: dist.Distribution): vmap_over(d, **{param_names[idx]: 1}), ) for idx in vmappable_param_idxs - if isinstance(params[idx], jnp.ndarray) and jnp.array(params[idx]).ndim > 0 + if isinstance(params[idx], jnp.ndarray) + and jnp.array(params[idx]).ndim > 0 # skip this distribution because _GeneralMixture.__init__ turns # 1d inputs into 0d attributes, thus breaks the expectations of # the vmapping test case where in_axes=1, only done for rank>=1 tensors. 
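Reviewer note (not part of the patch): nearly every hunk above is the mechanical output of two rules that `ruff format` shares with Black. First, a line over the default 88-column limit is exploded into one argument per line, with a trailing comma added; second, the "magic trailing comma" works in reverse, so a multi-line construct that would fit on one line is joined back together only if it has no trailing comma — the test_transforms.py hunk just below shows that collapse direction. As a minimal sketch of the first rule, here is the MCMC call from the test_ensemble_mcmc.py hunk above, before and after formatting (this is an illustration of the formatter's behavior, not standalone runnable code):

# Before: well over the default 88-column limit.
mcmc = MCMC(kernel, num_warmup=10, num_samples=10, progress_bar=False, num_chains=n_chain, chain_method=method)

# After `ruff format`: the over-long call becomes one argument per line,
# with a trailing comma so the multi-line layout survives future reformats.
mcmc = MCMC(
    kernel,
    num_warmup=10,
    num_samples=10,
    progress_bar=False,
    num_chains=n_chain,
    chain_method=method,
)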
diff --git a/test/test_transforms.py b/test/test_transforms.py index 901f5e1a5..b08c8606a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -118,9 +118,7 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): dict(), ), "reshape": T( - ReshapeTransform, - (), - {"forward_shape": (3, 4), "inverse_shape": (4, 3)} + ReshapeTransform, (), {"forward_shape": (3, 4), "inverse_shape": (4, 3)} ), } @@ -211,8 +209,8 @@ def check_transforms(t1, t2): ((3, 4), (4, 3), ()), ((7,), (7, 1), ()), ((3, 5), (15,), ()), - ((2, 4), (2, 2, 2), (17,)) - ] + ((2, 4), (2, 2, 2), (17,)), + ], ) def test_reshape_transform(forward_shape, inverse_shape, batch_shape): x = random.normal(random.key(29), batch_shape + inverse_shape) From bfd5343049aa26406a831e8c077b0e6612c7a4f4 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 14 Mar 2024 14:49:16 -0400 Subject: [PATCH 2/2] further ruff format --- examples/hsgp.py | 1 + examples/prodlda.py | 1 + examples/proportion_test.py | 1 - examples/stein_dmm.py | 1 + 4 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/hsgp.py b/examples/hsgp.py index 1cff5f69e..5b074f773 100644 --- a/examples/hsgp.py +++ b/examples/hsgp.py @@ -46,6 +46,7 @@ """ + import argparse import os diff --git a/examples/prodlda.py b/examples/prodlda.py index a0b1561fa..40d6444ca 100644 --- a/examples/prodlda.py +++ b/examples/prodlda.py @@ -30,6 +30,7 @@ .. image:: ../_static/img/examples/prodlda.png :align: center """ + import argparse import matplotlib.pyplot as plt diff --git a/examples/proportion_test.py b/examples/proportion_test.py index b185acbd1..8ca97d898 100644 --- a/examples/proportion_test.py +++ b/examples/proportion_test.py @@ -16,7 +16,6 @@ density interval for the effect of making a call. """ - import argparse import os diff --git a/examples/stein_dmm.py b/examples/stein_dmm.py index 140e886aa..2daeb055b 100644 --- a/examples/stein_dmm.py +++ b/examples/stein_dmm.py @@ -18,6 +18,7 @@ .. image:: ../_static/img/examples/stein_dmm.png :align: center """ + import argparse import numpy as np
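Reviewer note (not part of the patch): since both commits are intended to be pure formatting, a cheap way to confirm they are behavior-preserving is to run one of the smoke tests they touch. The sketch below is condensed from the reformatted test_ensemble_mcmc.py hunks above; it assumes, as that file does, that AIES is importable from numpyro.infer, and the data shape here is arbitrary rather than copied from the test.

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import AIES, MCMC

# Tiny logistic-regression setup, condensed from test_ensemble_mcmc.py.
dim = 3
data = random.normal(random.PRNGKey(0), (1000, dim))
true_coefs = jnp.arange(1.0, dim + 1.0)
labels = dist.Bernoulli(logits=jnp.sum(true_coefs * data, axis=-1)).sample(
    random.PRNGKey(1)
)


def model(labels):
    coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
    return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)


# Ensemble kernels require num_chains > 1 with chain_method="vectorized"
# (and an even chain count), which test_out_shape_smoke above exercises.
n_chains = 10
mcmc = MCMC(
    AIES(model),
    num_warmup=10,
    num_samples=10,
    progress_bar=False,
    num_chains=n_chains,
    chain_method="vectorized",
)
mcmc.run(random.PRNGKey(2), labels)
assert mcmc.get_samples(group_by_chain=True)["coefs"].shape[0] == n_chains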