Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ruff not format the changes #1761

Merged
merged 2 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
all: test

lint: FORCE
ruff .
ruff check .
ruff format . --check
python scripts/update_headers.py --check

license: FORCE
python scripts/update_headers.py

format: license FORCE
ruff . --fix
ruff format .

install: FORCE
pip install -e .[dev,doc,test,examples]
Expand Down
9 changes: 7 additions & 2 deletions examples/hmm_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,13 @@ def transition_fn(carry, y):
with numpyro.plate("sequences", num_sequences, dim=-2):
with mask(mask=(t < lengths)[..., None]):
probs_x_t = Vindex(probs_x)[x_prev, x_curr]
x_prev, x_curr = x_curr, numpyro.sample(
"x", dist.Categorical(probs_x_t), infer={"enumerate": "parallel"}
x_prev, x_curr = (
x_curr,
numpyro.sample(
"x",
dist.Categorical(probs_x_t),
infer={"enumerate": "parallel"},
),
)
with numpyro.plate("tones", data_dim, dim=-1):
probs_y_t = probs_y[x_curr.squeeze(-1)]
Expand Down
1 change: 1 addition & 0 deletions examples/hsgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"""

import argparse
import os

Expand Down
1 change: 1 addition & 0 deletions examples/prodlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
.. image:: ../_static/img/examples/prodlda.png
:align: center
"""

import argparse

import matplotlib.pyplot as plt
Expand Down
1 change: 0 additions & 1 deletion examples/proportion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
density interval for the effect of making a call.
"""


import argparse
import os

Expand Down
1 change: 1 addition & 0 deletions examples/stein_dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
.. image:: ../_static/img/examples/stein_dmm.png
:align: center
"""

import argparse

import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion numpyro/compat/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
loss_and_grads=None,
num_samples=10,
num_steps=0,
**kwargs
**kwargs,
):
super(SVI, self).__init__(model=model, guide=guide, optim=optim, loss=loss)
self.svi_state = None
Expand Down
1 change: 0 additions & 1 deletion numpyro/contrib/ecs_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def log_likelihood_sum(params_flat, subsample_indices=None):
ref_sum_log_lik_hessians = hessian(log_likelihood_sum)(ref_params_flat)

def gibbs_init(rng_key, gibbs_sites):

ref_subsamples_taylor = [
log_likelihood(ref_params_flat, gibbs_sites),
jacobian(log_likelihood)(ref_params_flat, gibbs_sites),
Expand Down
12 changes: 5 additions & 7 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,11 @@ def local_trace(key):

def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
# 0. Separate model and guide parameters, since only guide parameters are updated using Stein
non_mixture_uparams = (
{ # Includes any marked guide parameters and all model parameters
p: v
for p, v in unconstr_params.items()
if p not in self.guide_sites or self.non_mixture_params_fn(p)
}
)
non_mixture_uparams = { # Includes any marked guide parameters and all model parameters
p: v
for p, v in unconstr_params.items()
if p not in self.guide_sites or self.non_mixture_params_fn(p)
}
stein_uparams = {
p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
}
Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def random_flax_module(
input_shape=None,
apply_rng=None,
mutable=None,
**kwargs
**kwargs,
):
"""
A primitive to place a prior over the parameters of the Flax module `nn_module`.
Expand Down Expand Up @@ -372,7 +372,7 @@ def __call__(self, x):
input_shape=input_shape,
apply_rng=apply_rng,
mutable=mutable,
**kwargs
**kwargs,
)
params = nn.args[0]
new_params = deepcopy(params)
Expand Down
8 changes: 2 additions & 6 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,7 @@ def kl_divergence(p, q): # noqa: F811
_PyroDist.__doc__ = """
Wraps `{}.{} <https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/distributions/{}>`_
with :class:`~numpyro.contrib.tfp.distributions.TFPDistribution`.
""".format(
_Dist.__module__, _Dist.__name__, _Dist.__name__
)
""".format(_Dist.__module__, _Dist.__name__, _Dist.__name__)

__all__.append(_name)

Expand All @@ -328,9 +326,7 @@ def kl_divergence(p, q): # noqa: F811
{0}
----------------------------------------------------------------
.. autoclass:: numpyro.contrib.tfp.distributions.{0}
""".format(
_name
)
""".format(_name)
for _name in __all__[:_len_all]
]
)
8 changes: 2 additions & 6 deletions numpyro/contrib/tfp/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@ def sample(self, state, model_args, model_kwargs):
Wraps `{}.{} <https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/mcmc/{}>`_
with :class:`~numpyro.contrib.tfp.mcmc.TFPKernel`. The first argument `target_log_prob_fn`
in TFP kernel construction is replaced by either `model` or `potential_fn`.
""".format(
_Kernel.__module__, _Kernel.__name__, _Kernel.__name__
)
""".format(_Kernel.__module__, _Kernel.__name__, _Kernel.__name__)

__all__.append(_name)

Expand All @@ -250,9 +248,7 @@ def sample(self, state, model_args, model_kwargs):
{0}
----------------------------------------------------------------
.. autoclass:: numpyro.contrib.tfp.mcmc.{0}
""".format(
_name
)
""".format(_name)
for _name in __all__[:1] + sorted(__all__[1:])
]
)
3 changes: 3 additions & 0 deletions numpyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class BetaBinomial(Distribution):
Beta distribution.
:param numpy.ndarray total_count: number of Bernoulli trials.
"""

arg_constraints = {
"concentration1": constraints.positive,
"concentration0": constraints.positive,
Expand Down Expand Up @@ -107,6 +108,7 @@ class DirichletMultinomial(Distribution):
Dirichlet distribution.
:param numpy.ndarray total_count: number of Categorical trials.
"""

arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1),
"total_count": constraints.nonnegative_integer,
Expand Down Expand Up @@ -182,6 +184,7 @@ class GammaPoisson(Distribution):
:param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution.
:param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution.
"""

arg_constraints = {
"concentration": constraints.positive,
"rate": constraints.positive,
Expand Down
20 changes: 11 additions & 9 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,7 @@ def mean(self):
@property
def variance(self):
con0 = jnp.sum(self.concentration, axis=-1, keepdims=True)
return (
self.concentration * (con0 - self.concentration) / (con0**2 * (con0 + 1))
)
return self.concentration * (con0 - self.concentration) / (con0**2 * (con0 + 1))

@staticmethod
def infer_shapes(concentration):
Expand Down Expand Up @@ -909,6 +907,7 @@ def model(y): # y has dimension N x d
[1] `Generating random correlation matrices based on vines and extended onion method`,
Daniel Lewandowski, Dorota Kurowicka, Harry Joe
"""

arg_constraints = {"concentration": constraints.positive}
reparametrized_params = ["concentration"]
support = constraints.corr_matrix
Expand Down Expand Up @@ -985,6 +984,7 @@ def model(y): # y has dimension N x d
[1] `Generating random correlation matrices based on vines and extended onion method`,
Daniel Lewandowski, Dorota Kurowicka, Harry Joe
"""

arg_constraints = {"concentration": constraints.positive}
reparametrized_params = ["concentration"]
support = constraints.corr_cholesky
Expand Down Expand Up @@ -1961,9 +1961,10 @@ def scale_tril(self):

@lazy_property
def covariance_matrix(self):
covariance_matrix = add_diag(jnp.matmul(
self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)
), self.cov_diag)
covariance_matrix = add_diag(
jnp.matmul(self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)),
self.cov_diag,
)
return covariance_matrix

@lazy_property
Expand All @@ -1976,7 +1977,7 @@ def precision_matrix(self):
)
A = solve_triangular(Wt_Dinv, self._capacitance_tril, lower=True)
inverse_cov_diag = jnp.reciprocal(self.cov_diag)
return add_diag(- jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag)
return add_diag(-jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
Expand Down Expand Up @@ -2068,8 +2069,9 @@ class Pareto(TransformedDistribution):
def __init__(self, scale, alpha, *, validate_args=None):
self.scale, self.alpha = promote_shapes(scale, alpha)
batch_shape = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(alpha))
scale, alpha = jnp.broadcast_to(scale, batch_shape), jnp.broadcast_to(
alpha, batch_shape
scale, alpha = (
jnp.broadcast_to(scale, batch_shape),
jnp.broadcast_to(alpha, batch_shape),
)
base_dist = Exponential(alpha)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)]
Expand Down
1 change: 1 addition & 0 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ class Poisson(Distribution):
:param bool is_sparse: Whether to assume value is mostly zero when computing
:meth:`log_prob`, which can speed up computation when data is sparse.
"""

arg_constraints = {"rate": constraints.positive}
support = constraints.nonnegative_integer
pytree_aux_fields = ("is_sparse",)
Expand Down
8 changes: 7 additions & 1 deletion numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ class CholeskyTransform(ParameterFreeTransform):
Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
positive definite matrix.
"""

domain = constraints.positive_definite
codomain = constraints.lower_cholesky

Expand Down Expand Up @@ -444,6 +445,7 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a
c. Applies :math:`s_i = StickBreakingTransform(z_i)`.
d. Transforms back into signed domain: :math:`y_i = (sign(r_i), 1) * \sqrt{s_i}`.
"""

domain = constraints.real_vector
codomain = constraints.corr_cholesky

Expand Down Expand Up @@ -493,6 +495,7 @@ class CorrMatrixCholeskyTransform(CholeskyTransform):
Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
correlation matrix.
"""

domain = constraints.corr_matrix
codomain = constraints.corr_cholesky

Expand Down Expand Up @@ -624,6 +627,7 @@ class L1BallTransform(ParameterFreeTransform):
r"""
Transforms a uncontrained real vector :math:`x` into the unit L1 ball.
"""

domain = constraints.real_vector
codomain = constraints.l1_ball

Expand Down Expand Up @@ -687,6 +691,7 @@ class LowerCholeskyAffine(Transform):
>>> affine(base)
Array([0.3, 1.5], dtype=float32)
"""

domain = constraints.real_vector
codomain = constraints.real_vector

Expand Down Expand Up @@ -786,6 +791,7 @@ class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform):
and :math:`scale\_diag` is a diagonal matrix with all positive
entries that is parameterized with a softplus transform.
"""

domain = constraints.real_vector
codomain = constraints.scaled_unit_lower_cholesky

Expand Down Expand Up @@ -995,6 +1001,7 @@ class SoftplusTransform(ParameterFreeTransform):
Transform from unconstrained space to positive domain via softplus :math:`y = \log(1 + \exp(x))`.
The inverse is computed as :math:`x = \log(\exp(y) - 1)`.
"""

domain = constraints.real
codomain = constraints.softplus_positive

Expand Down Expand Up @@ -1200,7 +1207,6 @@ def __eq__(self, other):
)



##########################################################
# CONSTRAINT_REGISTRY
##########################################################
Expand Down
4 changes: 1 addition & 3 deletions numpyro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,7 @@ def mean(self):
if isinstance(self.base_dist, Normal):
low_prob = jnp.exp(self.log_prob(self.low))
high_prob = jnp.exp(self.log_prob(self.high))
return (
self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
)
return self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
elif isinstance(self.base_dist, Cauchy):
return jnp.full(self.batch_shape, jnp.nan)
else:
Expand Down
22 changes: 16 additions & 6 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,7 @@ def _setup_prototype(self, *args, **kwargs):
f"Expected {self.batch_ndim} batch dimensions, but site "
f"`{site['name']}` only has shape {shape}."
)
shape = shape[:self.batch_ndim]
shape = shape[: self.batch_ndim]
if batch_shape is None:
batch_shape = shape
elif shape != batch_shape:
Expand All @@ -1884,7 +1884,9 @@ def _get_posterior(self):
)

def _get_reshape_transform(self) -> ReshapeTransform:
return ReshapeTransform((self.latent_dim,), self._batch_shape + self._event_shape)
return ReshapeTransform(
(self.latent_dim,), self._batch_shape + self._event_shape
)


class AutoBatchedMultivariateNormal(AutoBatchedMixin, AutoContinuous):
Expand Down Expand Up @@ -1914,7 +1916,10 @@ def __init__(
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
self._init_scale = init_scale
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim,
model,
prefix=prefix,
init_loc_fn=init_loc_fn,
batch_ndim=batch_ndim,
)

def _get_batched_posterior(self):
Expand Down Expand Up @@ -2044,16 +2049,21 @@ def __init__(
self._init_scale = init_scale
self.rank = rank
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim,
model,
prefix=prefix,
init_loc_fn=init_loc_fn,
batch_ndim=batch_ndim,
)

def _get_batched_posterior(self):
rank = int(round(self._event_shape[0]**0.5)) if self.rank is None else self.rank
rank = (
int(round(self._event_shape[0] ** 0.5)) if self.rank is None else self.rank
)
init_latent = self._init_latent.reshape(self._batch_shape + self._event_shape)
loc = numpyro.param("{}_loc".format(self.prefix), init_latent)
cov_factor = numpyro.param(
"{}_cov_factor".format(self.prefix),
jnp.zeros(self._batch_shape + self._event_shape + (rank,))
jnp.zeros(self._batch_shape + self._event_shape + (rank,)),
)
scale = numpyro.param(
"{}_scale".format(self.prefix),
Expand Down
Loading
Loading