From 03030c27087a4ca68c7935f61b323dc76b938ebe Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 22 Dec 2023 12:18:45 +0100 Subject: [PATCH] Add ruff support (#1700) * add ruff * fix line-length * min python version * add requires-python section --- Makefile | 7 ++-- docs/source/conf.py | 2 +- numpyro/infer/mcmc.py | 4 +-- pyproject.toml | 68 +++++++++++++++++++++++++++++++++++++ setup.cfg | 20 ----------- setup.py | 4 +-- test/contrib/test_module.py | 31 +++++------------ 7 files changed, 83 insertions(+), 53 deletions(-) create mode 100644 pyproject.toml diff --git a/Makefile b/Makefile index 468e9c391..d654567f8 100644 --- a/Makefile +++ b/Makefile @@ -1,17 +1,14 @@ all: test lint: FORCE - flake8 - black --check . - isort --check . + ruff . python scripts/update_headers.py --check license: FORCE python scripts/update_headers.py format: license FORCE - black . - isort . + ruff . --fix install: FORCE pip install -e .[dev,doc,test,examples] diff --git a/docs/source/conf.py b/docs/source/conf.py index 66091fcda..07ee81b10 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -45,7 +45,7 @@ if "READTHEDOCS" not in os.environ: # if developing locally, use numpyro.__version__ as version - from numpyro import __version__ # noqaE402 + from numpyro import __version__ # noqa: E402 version = __version__ diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 0b1fb0858..977343802 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -422,11 +422,11 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): ) init_state = new_init_state if init_state is None else init_state sample_fn, postprocess_fn = self._get_cached_fns() - diagnostics = ( + diagnostics = ( # noqa: E731 lambda x: self.sampler.get_diagnostics_str(x[0]) if is_prng_key(rng_key) else "" - ) # noqa: E731 + ) init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) lower_idx = self._collection_params["lower"] upper_idx = self._collection_params["upper"] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..027ffa243 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,68 @@ +[project] +name = "numpyro" +requires-python = ">=3.9" + +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "docs/src", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + +[tool.ruff.lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = ["E4", "E7", "E9", "F"] +ignore = ["E203"] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +[tool.ruff.extend-per-file-ignores] +"numpyro/contrib/tfp/distributions.py" = ["F811"] +"numpyro/distributions/kl.py" = ["F811"] diff --git a/setup.cfg b/setup.cfg index 1864aa10f..8f78cac6b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,23 +1,3 @@ -[flake8] -max-line-length = 120 -exclude = docs/src, build, dist, .ipynb_checkpoints -ignore = W503,E203 -per-file-ignores = - numpyro/contrib/tfp/distributions.py:F811 - numpyro/distributions/kl.py:F811 - -[isort] -profile = black -skip_glob = .ipynb_checkpoints -known_first_party = funsor, numpyro, test -known_third_party = opt_einsum -known_jax = flax, haiku, jax, optax, tensorflow_probability -sections = FUTURE, STDLIB, THIRDPARTY, JAX, FIRSTPARTY, LOCALFOLDER -force_sort_within_sections = true -combine_as_imports = true -multi_line_output = 3 -skip=docs - [tool:pytest] filterwarnings = error ignore:numpy.ufunc size changed,:RuntimeWarning diff --git a/setup.py b/setup.py index fc346e037..e6b13e248 100644 --- a/setup.py +++ b/setup.py @@ -50,10 +50,8 @@ ], "test": [ "importlib-metadata<5.0", - "black[jupyter]>=21.8b0", - "flake8", "importlib-metadata<5.0", - "isort>=5.0", + "ruff>=0.1.8", "pytest>=4.1", "pyro-api>=0.1.1", "scipy>=1.9", diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 03513e110..20c8717f3 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -4,13 +4,13 @@ from copy import deepcopy import numpy as np -from numpy.testing import assert_allclose import pytest - from jax import random from jax.tree_util import tree_all, tree_map +from numpy.testing import assert_allclose import numpyro +import numpyro.distributions as dist from numpyro import handlers from numpyro.contrib.module import ( ParamShape, @@ -20,12 +20,9 @@ random_flax_module, random_haiku_module, ) -import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS -pytestmark = pytest.mark.filterwarnings( - "ignore:jax.tree_.+ is deprecated:FutureWarning" -) +pytestmark = pytest.mark.filterwarnings("ignore:jax.tree_.+ is deprecated:FutureWarning") def haiku_model_by_shape(x, y): @@ -119,16 +116,12 @@ def test_haiku_module(): 100, 100, ) - assert haiku_tr["nn$params"]["value"]["test_haiku_module/w_linear"]["b"].shape == ( - 100, - ) + assert haiku_tr["nn$params"]["value"]["test_haiku_module/w_linear"]["b"].shape == (100,) assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["w"].shape == ( 100, 100, ) - assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["b"].shape == ( - 100, - ) + assert haiku_tr["nn$params"]["value"]["test_haiku_module/x_linear"]["b"].shape == (100,) def test_update_params(): @@ -137,9 +130,7 @@ def test_update_params(): new_params = deepcopy(params) with handlers.seed(rng_seed=0): _update_params(params, new_params, prior) - assert params == { - "a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))} - } + assert params == {"a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))}} tree_all( tree_map( @@ -194,7 +185,7 @@ def test_random_module_mcmc(backend, init, callable_prior): kwargs = {} if callable_prior: - prior = ( + prior = ( # noqa: E731 lambda name, shape: dist.Cauchy() if name == bias_name else dist.Normal() ) else: @@ -206,9 +197,7 @@ def model(data, labels): numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels) kernel = NUTS(model=model) - mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False - ) + mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False) mcmc.run(random.PRNGKey(2), data, labels) mcmc.print_summary() samples = mcmc.get_samples() @@ -232,9 +221,7 @@ def fn(x): if dropout: x = hk.dropout(hk.next_rng_key(), 0.5, x) if batchnorm: - x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)( - x, is_training=True - ) + x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(x, is_training=True) return x def model():