Skip to content

Commit

Permalink
ci(github): add mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed Sep 11, 2024
1 parent ac37328 commit 4e8b8fd
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 50 deletions.
15 changes: 14 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,25 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}

mypy:
name: Mypy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: |
python -m pip install --upgrade pip
python -m pip install mypy
mypy .
release:
name: Release
runs-on: ubuntu-latest
concurrency:
group: release
needs: test
needs: [test, mypy]
if: github.event_name == 'push' && github.repository == 'kmnhan/erlabpy'
environment:
name: pypi
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,9 @@ exclude = [
"^docs/",
"^tests/",
"_deprecated/",
"interactive/fermiedge.py",
"interactive/bzplot.py",
"interactive/curvefittingtool.py",
"io/plugins/",
"io/dataloader.py",
]

[[tool.mypy.overrides]]
Expand Down
46 changes: 22 additions & 24 deletions src/erlab/interactive/bzplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __post_init__(self, execute=None):
if execute is None:
execute = True
try:
shell = get_ipython().__class__.__name__ # type: ignore
shell = get_ipython().__class__.__name__ # pyright: ignore[reportUndefinedVariable]
if shell in ["ZMQInteractiveShell", "TerminalInteractiveShell"]:
execute = False
except NameError:
Expand Down Expand Up @@ -150,23 +150,20 @@ def __init__(self, bvec) -> None:
"notrack": True,
},
_1={
"widget": QtWidgets.QLabel(
"𝑥", alignment=QtCore.Qt.AlignmentFlag.AlignHCenter
),
"widget": QtWidgets.QLabel("𝑥"),
"alignment": QtCore.Qt.AlignmentFlag.AlignHCenter,
"showlabel": False,
"notrack": True,
},
_2={
"widget": QtWidgets.QLabel(
"𝑦", alignment=QtCore.Qt.AlignmentFlag.AlignHCenter
),
"widget": QtWidgets.QLabel("𝑦"),
"alignment": QtCore.Qt.AlignmentFlag.AlignHCenter,
"showlabel": False,
"notrack": True,
},
_3={
"widget": QtWidgets.QLabel(
"𝑧", alignment=QtCore.Qt.AlignmentFlag.AlignHCenter
),
"widget": QtWidgets.QLabel("𝑧"),
"alignment": QtCore.Qt.AlignmentFlag.AlignHCenter,
"showlabel": False,
"notrack": True,
},
Expand Down Expand Up @@ -213,23 +210,20 @@ def __init__(self, bvec) -> None:
"notrack": True,
},
_1={
"widget": QtWidgets.QLabel(
"𝑥", alignment=QtCore.Qt.AlignmentFlag.AlignHCenter
),
"widget": QtWidgets.QLabel("𝑥"),
"alignment": QtCore.Qt.AlignmentFlag.AlignHCenter,
"showlabel": False,
"notrack": True,
},
_2={
"widget": QtWidgets.QLabel(
"𝑦", alignment=QtCore.Qt.AlignmentFlag.AlignHCenter
),
"widget": QtWidgets.QLabel("𝑦"),
"alignment": QtCore.Qt.AlignmentFlag.AlignHCenter,
"showlabel": False,
"notrack": True,
},
_3={
"widget": QtWidgets.QLabel(
"𝑧", alignment=QtCore.Qt.AlignmentFlag.AlignHCenter
),
"widget": QtWidgets.QLabel("𝑧"),
"alignment": QtCore.Qt.AlignmentFlag.AlignHCenter,
"showlabel": False,
"notrack": True,
},
Expand Down Expand Up @@ -346,15 +340,19 @@ def bvec_val(self):
class BZPlotWidget(QtWidgets.QWidget):
def __init__(self, bvec) -> None:
super().__init__()
self.setLayout(QtWidgets.QVBoxLayout(self))
layout = QtWidgets.QVBoxLayout(self)
self.setLayout(layout)

self.set_bvec(bvec, update=False)

self._canvas = FigureCanvas(Figure())
self.layout().addWidget(NavigationToolbar(self._canvas, self))
self.layout().addWidget(self._canvas)
layout.addWidget(NavigationToolbar(self._canvas, self))
layout.addWidget(self._canvas)

self.ax = self._canvas.figure.add_subplot(projection="3d")
self.ax = cast(
mpl_toolkits.mplot3d.axes3d.Axes3D,
self._canvas.figure.add_subplot(projection="3d"),
)
self.ax.axis("off")

self._lc = mpl_toolkits.mplot3d.art3d.Line3DCollection(
Expand All @@ -378,7 +376,7 @@ def __init__(self, bvec) -> None:
p = self.ax.plot(*[(0, bi) for bi in b], "-", c=f"C{i + 1}", clip_on=False)
t = self.ax.text(
*(b + 0.15 * b / np.linalg.norm(b)),
f"$b_{i + 1}$",
s=f"$b_{i + 1}$",
c=f"C{i + 1}",
ha="center",
va="center_baseline",
Expand Down
2 changes: 1 addition & 1 deletion src/erlab/interactive/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def copy_code(self) -> str:
def dtool(data, data_name: str | None = None, *, execute: bool | None = None):
if data_name is None:
try:
data_name = varname.argname("data", func=dtool, vars_only=False) # type: ignore[assignment]
data_name = str(varname.argname("data", func=dtool, vars_only=False))
except varname.VarnameRetrievingError:
data_name = "data"

Expand Down
54 changes: 35 additions & 19 deletions src/erlab/interactive/fermiedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast

import joblib
import numpy as np
Expand All @@ -25,6 +25,8 @@
from erlab.utils.parallel import joblib_progress_qt

if TYPE_CHECKING:
from collections.abc import Callable

import lmfit
import scipy.interpolate

Expand Down Expand Up @@ -70,6 +72,8 @@ def set_params(self, data, x0, y0, x1, y1, params) -> None:
return_as="list",
pre_dispatch="n_jobs",
)
self.edge_center: xr.DataArray | None = None
self.edge_stderr: xr.DataArray | None = None

@QtCore.Slot()
def abort_fit(self) -> None:
Expand Down Expand Up @@ -147,7 +151,9 @@ def __init__(
if data_name is None:
try:
self._argnames["data"] = varname.argname(
"data", func=self.__init__, vars_only=False
"data",
func=self.__init__, # type: ignore[misc]
vars_only=False,
)
except varname.VarnameRetrievingError:
self._argnames["data"] = "gold"
Expand All @@ -157,7 +163,9 @@ def __init__(
if data_corr is not None:
try:
self._argnames["data_corr"] = varname.argname(
"data_corr", func=self.__init__, vars_only=False
"data_corr",
func=self.__init__, # type: ignore[misc]
vars_only=False,
)
except varname.VarnameRetrievingError:
self._argnames["data_corr"] = "data_corr"
Expand All @@ -177,7 +185,7 @@ def __init__(
except KeyError:
temp = 30.0

self.params_roi = ROIControls(self.add_roi(0))
self.params_roi = ROIControls(self.aw.add_roi(0))
self.params_edge = ParameterGroup(
{
"T (K)": {"qwtype": "dblspin", "value": temp, "range": (0.0, 400.0)},
Expand All @@ -202,7 +210,9 @@ def __init__(
}
)

self.params_edge.widgets["Fast"].stateChanged.connect(self._toggle_fast)
cast(
QtWidgets.QCheckBox, self.params_edge.widgets["Fast"]
).stateChanged.connect(self._toggle_fast)

self.params_poly = ParameterGroup(
{
Expand Down Expand Up @@ -258,13 +268,13 @@ def __init__(
},
}
)
_auto_check = cast(QtWidgets.QCheckBox, self.params_spl.widgets["Auto"])
self.params_spl.widgets["lambda"].setDisabled(
self.params_spl.widgets["Auto"].checkState() == QtCore.Qt.CheckState.Checked
_auto_check.checkState() == QtCore.Qt.CheckState.Checked
)
self.params_spl.widgets["Auto"].toggled.connect(
_auto_check.toggled.connect(
lambda _: self.params_spl.widgets["lambda"].setDisabled(
self.params_spl.widgets["Auto"].checkState()
== QtCore.Qt.CheckState.Checked
_auto_check.checkState() == QtCore.Qt.CheckState.Checked
)
)

Expand Down Expand Up @@ -315,15 +325,17 @@ def __init__(
# Setup progress bar
self.progress: QtWidgets.QProgressDialog = QtWidgets.QProgressDialog(
labelText="Fitting...",
cancelButtonText="Abort!",
minimum=0,
maximum=100,
parent=self,
minimumDuration=0,
windowModality=QtCore.Qt.WindowModal,
)

self.pbar: QtWidgets.QProgressBar = QtWidgets.QProgressBar()
self.progress.setMinimumDuration(0)
self.progress.setWindowModality(QtCore.Qt.WindowModality.WindowModal)
self.progress.setBar(self.pbar)
self.progress.setFixedSize(self.progress.size())
self.progress.setCancelButtonText("Abort!")
self.progress.canceled.disconnect(self.progress.cancel) # don't auto close
self.progress.canceled.connect(self.abort_fit)
self.progress.setAutoReset(False)
Expand Down Expand Up @@ -351,8 +363,12 @@ def __init__(
self.__post_init__(execute=execute)

def _toggle_fast(self) -> None:
self.params_edge.widgets["T (K)"].setDisabled(self.params_edge.values["Fast"])
self.params_edge.widgets["Fix T"].setDisabled(self.params_edge.values["Fast"])
self.params_edge.widgets["T (K)"].setDisabled(
bool(self.params_edge.values["Fast"])
)
self.params_edge.widgets["Fix T"].setDisabled(
bool(self.params_edge.values["Fast"])
)

def iterated(self, n: int) -> None:
self.step_times.append(time.perf_counter() - self.start_time)
Expand Down Expand Up @@ -396,8 +412,8 @@ def closeEvent(self, event) -> None:
def post_fit(self) -> None:
self.progress.reset()
self.edge_center, self.edge_stderr = (
self.fitter.edge_center,
self.fitter.edge_stderr,
cast(xr.DataArray, self.fitter.edge_center),
cast(xr.DataArray, self.fitter.edge_stderr),
)

xval = self.edge_center.alpha.values
Expand Down Expand Up @@ -484,7 +500,7 @@ def gen_code(self, mode: str) -> None:
p1 = self.params_spl.values
x0, y0, x1, y1 = (float(np.round(x, 3)) for x in self.params_roi.roi_limits)

arg_dict = {
arg_dict: dict[str, Any] = {
"angle_range": (x0, x1),
"eV_range": (y0, y1),
"bin_size": (p0["Bin x"], p0["Bin y"]),
Expand All @@ -494,7 +510,7 @@ def gen_code(self, mode: str) -> None:

match mode:
case "poly":
func = erlab.analysis.gold.poly
func: Callable = erlab.analysis.gold.poly
arg_dict["degree"] = p1["Degree"]
case "spl":
func = erlab.analysis.gold.spline
Expand Down Expand Up @@ -556,7 +572,7 @@ def goldtool(
"""
if data_name is None:
try:
data_name = varname.argname("data", func=goldtool, vars_only=False)
data_name = str(varname.argname("data", func=goldtool, vars_only=False))
except varname.VarnameRetrievingError:
data_name = "data"
return GoldTool(data, data_corr, data_name=data_name, **kwargs)
2 changes: 1 addition & 1 deletion src/erlab/interactive/kspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def ktool(
"""Interactive momentum conversion tool."""
if data_name is None:
try:
data_name = varname.argname("data", func=ktool, vars_only=False) # type: ignore[assignment]
data_name = str(varname.argname("data", func=ktool, vars_only=False))
except varname.VarnameRetrievingError:
data_name = "data"

Expand Down
11 changes: 9 additions & 2 deletions src/erlab/interactive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,9 @@ def __init__(

self.global_connect()

def layout(self) -> QtWidgets.QGridLayout:
return cast(QtWidgets.QGridLayout, super().layout())

@staticmethod
def getParameterWidget(
qwtype: (
Expand Down Expand Up @@ -1091,6 +1094,8 @@ def getParameterWidget(

policy = kwargs.pop("policy", None)

alignment = kwargs.pop("alignment", None)

if qwtype == "fitparam":
show_param_label = kwargs.pop("show_param_label", False)
kwargs["show_label"] = show_param_label
Expand Down Expand Up @@ -1133,6 +1138,8 @@ def getParameterWidget(
widget.setFixedHeight(fixedHeight)
if policy is not None:
widget.setSizePolicy(*policy)
if alignment is not None:
widget.setAlignment(alignment)

if value is not None:
widget.setValue(value)
Expand Down Expand Up @@ -1279,7 +1286,7 @@ def __init__(self, roi: pg.ROI, spinbox_kw: dict | None = None, **kwargs) -> Non
},
**kwargs,
)
self.draw_button = self.widgets["drawbtn"]
self.draw_button = cast(QtWidgets.QPushButton, self.widgets["drawbtn"])
self.roi_spin = [self.widgets[i] for i in ["x0", "y0", "x1", "y1"]]
self.roi.sigRegionChanged.connect(self.update_pos)

Expand Down Expand Up @@ -1465,7 +1472,7 @@ def __post_init__(self, execute=None):
if execute is None:
execute = True
try:
shell = get_ipython().__class__.__name__ # type: ignore
shell = get_ipython().__class__.__name__ # pyright: ignore[reportUndefinedVariable]
if shell in ["ZMQInteractiveShell", "TerminalInteractiveShell"]:
execute = False
except NameError:
Expand Down

0 comments on commit 4e8b8fd

Please sign in to comment.