From 0f7329f8323bd1345b982e03b20ddbad2070ae48 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Mon, 19 Aug 2024 09:28:30 +0100 Subject: [PATCH] Move to pyright and fix type errors (#135) * Move to pyright and fix type errors --- pyproject.toml | 32 ++++++++++++++++++++------------ src/scanspec/cli.py | 3 +++ src/scanspec/core.py | 12 ++++++------ src/scanspec/plot.py | 18 ++++++++++++------ src/scanspec/service.py | 3 +-- src/scanspec/sphinxext.py | 3 ++- 6 files changed, 44 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ab4dcf9f..c5b48619 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,11 +29,11 @@ dev = [ "scanspec[plotting]", "scanspec[service]", "copier", - "mypy", "myst-parser", "pipdeptree", "pre-commit", "pydata-sphinx-theme>=0.12", + "pyright", "pytest", "pytest-cov", "ruff", @@ -61,8 +61,9 @@ name = "Tom Cobb" [tool.setuptools_scm] write_to = "src/scanspec/_version.py" -[tool.mypy] -ignore_missing_imports = true # Ignore missing stubs in imported modules +[tool.pyright] +# strict = ["src", "tests"] +reportMissingImports = false # Ignore missing stubs in imported modules [tool.pytest.ini_options] # Run pytest with all our checkers, and don't spam us with massive tracebacks on error @@ -95,12 +96,12 @@ passenv = * allowlist_externals = pytest pre-commit - mypy + pyright sphinx-build sphinx-autobuild commands = pre-commit: pre-commit run --all-files {posargs} - type-checking: mypy src tests {posargs} + type-checking: pyright src tests {posargs} tests: pytest --cov=scanspec --cov-report term --cov-report xml:cov.xml {posargs} docs: sphinx-{posargs:build -E --keep-going} -T docs build/html """ @@ -111,14 +112,21 @@ line-length = 88 [tool.ruff.lint] extend-select = [ - "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b - "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 - "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e - "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f - "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w - "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i - "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up + "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b + "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 + "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e + "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f + "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w + "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i + "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up + "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self ] ignore = [ "B008", # We use function calls in service arguments ] + +[tool.ruff.lint.per-file-ignores] +# By default, private member access is allowed in tests +# See https://github.com/DiamondLightSource/python-copier-template/issues/154 +# Remove this line to forbid private member access in tests +"tests/**/*" = ["SLF001"] diff --git a/src/scanspec/cli.py b/src/scanspec/cli.py index 7c0e4de3..33f56520 100644 --- a/src/scanspec/cli.py +++ b/src/scanspec/cli.py @@ -25,6 +25,9 @@ def cli(ctx, log_level: str): # if no command is supplied, print the help message if ctx.invoked_subcommand is None: + # We need to prove that cli has been converted to a command + # by the click decorator to keep pyright happy. + assert isinstance(cli, click.Command) click.echo(cli.get_help(ctx)) diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 74e0dffd..6cb62d47 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -35,11 +35,14 @@ StrictConfig: ConfigDict = {"extra": "forbid"} +C = TypeVar("C") +T = TypeVar("T", type, Callable) + def discriminated_union_of_subclasses( - super_cls: type, + super_cls: type[C], discriminator: str = "type", -) -> type: +) -> type[C]: """Add all subclasses of super_cls to a discriminated union. For all subclasses of super_cls, add a discriminator field to identify @@ -137,9 +140,6 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler): return super_cls -T = TypeVar("T", type, Callable) - - def uses_tagged_union(cls_or_func: T) -> T: """ T = TypeVar("T", type, Callable) @@ -562,7 +562,7 @@ def __init__( self.lengths = np.array([len(f) for f in stack]) #: Index of the end frame, one more than the last index that will be #: produced - self.end_index = np.prod(self.lengths) + self.end_index = int(np.prod(self.lengths)) if num is not None and start + num < self.end_index: self.end_index = start + num diff --git a/src/scanspec/plot.py b/src/scanspec/plot.py index 43311663..e4fc1e8c 100644 --- a/src/scanspec/plot.py +++ b/src/scanspec/plot.py @@ -33,7 +33,7 @@ def __init__(self, xs, ys, zs, *args, **kwargs): # Added here because of https://github.com/matplotlib/matplotlib/issues/21688 def do_3d_projection(self, renderer=None): xs3d, ys3d, zs3d = self._verts3d - xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) + xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) # type: ignore self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) return np.min(zs) @@ -109,11 +109,17 @@ def plot_spec(spec: Spec[Any], title: str | None = None): # Setup axes if ndims > 2: plt.figure(figsize=(6, 6)) - plt_axes: Axes3D = plt.axes(projection="3d") + plt_axes = plt.axes(projection="3d") plt_axes.grid(False) - plt_axes.set_zlabel(axes[-3]) - plt_axes.set_ylabel(axes[-2]) - plt_axes.view_init(elev=15) + if isinstance(plt_axes, Axes3D): + plt_axes.set_zlabel(axes[-3]) + plt_axes.set_ylabel(axes[-2]) + plt_axes.view_init(elev=15) + else: + raise TypeError( + "Expected matplotlib to create an Axes3D object, " + f"instead got: {plt_axes}" + ) elif ndims == 2: plt.figure(figsize=(6, 6)) plt_axes = plt.axes() @@ -208,7 +214,7 @@ def plot_spec(spec: Spec[Any], title: str | None = None): _plot_arrow(plt_axes, arrow_arr) elif splines: # Plot the starting arrow in the direction of the first point - arrow_arr = [(2 * a[0] - a[1], a[0]) for a in splines[0]] + arrow_arr = [np.array([2 * a[0] - a[1], a[0]]) for a in splines[0]] _plot_arrow(plt_axes, arrow_arr) else: # First point isn't moving, put a right caret marker diff --git a/src/scanspec/service.py b/src/scanspec/service.py index 8b4b2cc4..b05e4113 100644 --- a/src/scanspec/service.py +++ b/src/scanspec/service.py @@ -7,7 +7,6 @@ from fastapi import Body, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from fastapi.responses import JSONResponse from pydantic import Field from pydantic.dataclasses import dataclass @@ -127,7 +126,7 @@ class SmallestStepResponse: @app.post("/valid", response_model=ValidResponse) def valid( spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]), -) -> ValidResponse | JSONResponse: +) -> ValidResponse: """Validate wether a ScanSpec can produce a viable scan. Args: diff --git a/src/scanspec/sphinxext.py b/src/scanspec/sphinxext.py index ecde40d9..6a1e2630 100644 --- a/src/scanspec/sphinxext.py +++ b/src/scanspec/sphinxext.py @@ -1,5 +1,6 @@ from contextlib import contextmanager +from docutils.statemachine import StringList from matplotlib.sphinxext import plot_directive from . import __version__ @@ -25,7 +26,7 @@ class ExampleSpecDirective(plot_directive.PlotDirective): """Runs `plot_spec` on the ``spec`` definied in the content.""" def run(self): - self.content = ( + self.content = StringList( ["# Example Spec", "", "from scanspec.plot import plot_spec"] + [str(x) for x in self.content] + ["plot_spec(spec)"]