Skip to content

Commit

Permalink
pyright fixes (#3777)
Browse files Browse the repository at this point in the history
* enable pyright

* fix electronic_structure.plotters

* update pre-commit

* finish `electronic_structure`

* fix `core.surface/tensors/tracjectory`

* fix unit test

* fix core.surface

* add DEBUG tag

* fix core.interface

* fix command_line

* fix some analysis

* fix some analysis

* fix more analysis

* fix np.empty

* fix more analysis

* fix chemenv.utils

* finish chemenv

* revert changes that break tests

* Revert "revert changes that break tests"

This reverts commit de34eda.

* fix unit test

* remove tag

* remove TODO tag

* try downgrade pyright

* Revert "try downgrade pyright"

This reverts commit 079ebca.

* add venv path for pyright

* Revert "add venv path for pyright"

This reverts commit f9a3aeb.

* skip pyright in pre-commit

* increase `requests` timeout from 60s to 600s

* skip tests/ext/test_cod.py in CI

* fix plot_slab weird way to calc zorder

* internal types

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
DanielYang59 and janosh authored Apr 23, 2024
1 parent 7e2f16c commit 1367746
Show file tree
Hide file tree
Showing 50 changed files with 390 additions and 208 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:

- name: Install dependencies
run: |
pip install --upgrade ruff mypy
pip install --upgrade ruff mypy pyright
- name: ruff
run: |
Expand All @@ -31,7 +31,7 @@ jobs:
ruff format --check .
- name: mypy
run: |
mypy --version
rm -rf .mypy_cache
mypy ${{ github.event.repository.name }}
run: mypy ${{ github.event.repository.name }}

- name: pyright
run: pyright
11 changes: 8 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ exclude: ^(docs|tests/files|cmd_line|tasks.py)

ci:
autoupdate_schedule: monthly
skip: [mypy]
skip: [mypy, pyright]
autofix_commit_msg: pre-commit auto-fixes
autoupdate_commit_msg: pre-commit autoupdate

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.7
rev: v0.4.1
hooks:
- id: ruff
args: [--fix, --unsafe-fixes]
Expand All @@ -35,7 +35,7 @@ repos:
additional_dependencies: [tomli] # needed to read pyproject.toml below py3.11

- repo: https://github.com/MarcoGorelli/cython-lint
rev: v0.16.0
rev: v0.16.2
hooks:
- id: cython-lint
args: [--no-pycodestyle]
Expand All @@ -62,3 +62,8 @@ repos:
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.359
hooks:
- id: pyright
2 changes: 1 addition & 1 deletion dev_scripts/update_pt_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def gen_iupac_ordering():
def add_electron_affinities():
"""Update the periodic table data file with electron affinities."""

req = requests.get("https://wikipedia.org/wiki/Electron_affinity_(data_page)", timeout=60)
req = requests.get("https://wikipedia.org/wiki/Electron_affinity_(data_page)", timeout=600)
soup = BeautifulSoup(req.text, "html.parser")
table = None
for table in soup.find_all("table"):
Expand Down
14 changes: 7 additions & 7 deletions pymatgen/analysis/adsorption.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,19 +680,19 @@ def plot_slab(
sites = list(reversed(sites))
coords = np.array(reversed(coords))
# Draw circles at sites and stack them accordingly
for n, coord in enumerate(coords):
radius = sites[n].species.elements[0].atomic_radius * scale
ax.add_patch(patches.Circle(coord[:2] - lattice_sum * (repeat // 2), radius, color="w", zorder=2 * n))
color = color_dict[sites[n].species.elements[0].symbol]
for idx, coord in enumerate(coords):
radius = sites[idx].species.elements[0].atomic_radius * scale
ax.add_patch(patches.Circle(coord[:2] - lattice_sum * (repeat // 2), radius, color="w", zorder=2 * idx))
color = color_dict[sites[idx].species.elements[0].symbol]
ax.add_patch(
patches.Circle(
coord[:2] - lattice_sum * (repeat // 2),
radius,
facecolor=color,
alpha=alphas[n],
alpha=alphas[idx],
edgecolor="k",
lw=0.3,
zorder=2 * n + 1,
zorder=2 * idx + 1,
)
)
# Adsorption sites
Expand All @@ -714,7 +714,7 @@ def plot_slab(
codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY]
vertices = [(np.array(vert) + corner).tolist() for vert in vertices]
path = Path(vertices, codes)
patch = patches.PathPatch(path, facecolor="none", lw=2, alpha=0.5, zorder=2 * n + 2)
patch = patches.PathPatch(path, facecolor="none", lw=2, alpha=0.5, zorder=2 * len(coords) + 2)
ax.add_patch(patch)
ax.set_aspect("equal")
center = corner + lattice_sum / 2.0
Expand Down
21 changes: 11 additions & 10 deletions pymatgen/analysis/bond_dissociation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def fragment_and_process(self, bonds):
bonds (list): bonds to process.
"""
# Try to split the principle:
fragments: list[MoleculeGraph] = []
try:
frags = self.mol_graph.split_molecule_subgraphs(bonds, allow_reverse=True)
fragments = self.mol_graph.split_molecule_subgraphs(bonds, allow_reverse=True)
frag_success = True

except MolGraphSplitError:
Expand Down Expand Up @@ -166,7 +167,7 @@ def fragment_and_process(self, bonds):
# We shouldn't ever encounter more than one good entry.
raise RuntimeError("There should only be one valid ring opening fragment! Exiting...")
elif len(bonds) == 2:
raise RuntimeError("Should only be trying to break two bonds if multibreak is true! Exiting...")
raise RuntimeError("Should only be trying to break two bonds if multibreak=True! Exiting...")
else:
raise ValueError("No reason to try and break more than two bonds at once! Exiting...")
frag_success = False
Expand All @@ -176,19 +177,19 @@ def fragment_and_process(self, bonds):
# As above, we begin by making sure we haven't already encountered an identical pair of fragments:
frags_done = False
for frag_pair in self.done_frag_pairs:
if frag_pair[0].isomorphic_to(frags[0]):
if frag_pair[1].isomorphic_to(frags[1]):
if frag_pair[0].isomorphic_to(fragments[0]):
if frag_pair[1].isomorphic_to(fragments[1]):
frags_done = True
break
elif frag_pair[1].isomorphic_to(frags[0]) and frag_pair[0].isomorphic_to(frags[1]):
elif frag_pair[1].isomorphic_to(fragments[0]) and frag_pair[0].isomorphic_to(fragments[1]):
frags_done = True
break
if not frags_done:
# If we haven't, we save this pair and search for the relevant fragment entries:
self.done_frag_pairs += [frags]
self.done_frag_pairs += [fragments]
n_entries_for_this_frag_pair = 0
frag1_entries = self.search_fragment_entries(frags[0])
frag2_entries = self.search_fragment_entries(frags[1])
frag1_entries = self.search_fragment_entries(fragments[0])
frag2_entries = self.search_fragment_entries(fragments[1])
frag1_charges_found = []
frag2_charges_found = []
# We then check for our expected charges of each fragment:
Expand All @@ -200,14 +201,14 @@ def fragment_and_process(self, bonds):
frag2_charges_found += [frag2["initial_molecule"]["charge"]]
# If we're missing some of either, tell the user:
if len(frag1_charges_found) < len(self.expected_charges):
bb = BabelMolAdaptor(frags[0].molecule)
bb = BabelMolAdaptor(fragments[0].molecule)
pb_mol = bb.pybel_mol
smiles = pb_mol.write("smi").split()[0]
for charge in self.expected_charges:
if charge not in frag1_charges_found:
warnings.warn(f"Missing {charge=} for fragment {smiles}")
if len(frag2_charges_found) < len(self.expected_charges):
bb = BabelMolAdaptor(frags[1].molecule)
bb = BabelMolAdaptor(fragments[1].molecule)
pb_mol = bb.pybel_mol
smiles = pb_mol.write("smi").split()[0]
for charge in self.expected_charges:
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/bond_valence.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def get_valences(self, structure: Structure):
valences.append(vals)

# make variables needed for recursion
attrib = []
if structure.is_ordered:
n_sites = np.array(list(map(len, equi_sites)))
valence_min = np.array(list(map(min, valences)))
Expand Down Expand Up @@ -324,7 +325,6 @@ def _recurse(assigned=None):
else:
n_sites = np.array([len(sites) for sites in equi_sites])
tmp = []
attrib = []
for idx, n_site in enumerate(n_sites):
for _ in valences[idx]:
tmp.append(n_site)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import abc
import os
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING

import numpy as np
from monty.json import MSONable
Expand All @@ -31,6 +31,8 @@
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

if TYPE_CHECKING:
from typing import ClassVar

from typing_extensions import Self

__author__ = "David Waroquiers"
Expand Down Expand Up @@ -256,6 +258,7 @@ def equivalent_site_index_and_transform(self, psite):
Equivalent site in the unit cell, translations and symmetry transformation.
"""
# Get the index of the site in the unit cell of which the PeriodicSite psite is a replica.
isite = 0
try:
isite = self.structure_environments.structure.index(psite)
except ValueError:
Expand Down Expand Up @@ -283,6 +286,8 @@ def equivalent_site_index_and_transform(self, psite):
# that gets back the site to the unit cell (Translation III)
# TODO: check that these tolerances are needed, now that the structures are refined before analyzing envs
tolerances = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4]
d_this_site2 = (0, 0, 0)
sym_trafo = None
for tolerance in tolerances:
for sym_op in self.symops:
new_site = PeriodicSite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,9 @@ def edges(self, sites, permutation=None, input="sites"):
coords = [site.coords for site in sites]
elif input == "coords":
coords = sites
else:
raise RuntimeError("Invalid input for edges.")

# if permutation is None:
# coords = [site.coords for site in sites]
# else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@ def compute_structure_environments(

to_add_from_hints = []
nb_sets_info = {}
cn = 0

for cn, nb_sets in struct_envs.neighbors_sets[isite].items():
if cn not in all_cns:
Expand Down Expand Up @@ -1534,8 +1535,7 @@ def coordination_geometry_symmetry_measures_separation_plane(
algos = []
perfect2local_maps = []
local2perfect_maps = []
if testing:
separation_permutations = []
separation_permutations = []
nplanes = 0
for npoints in range(
separation_plane_algo.minimum_number_of_points,
Expand Down Expand Up @@ -1793,9 +1793,10 @@ def _cg_csm_separation_plane(
plane_found = False
permutations = []
permutations_symmetry_measures = []
if testing:
separation_permutations = []
separation_permutations = []
dist_tolerances = dist_tolerances or DIST_TOLERANCES
algo = ""

for dist_tolerance in dist_tolerances:
algo = "NOT_FOUND"
separation = local_plane.indices_separate(self.local_geometry._coords, dist_tolerance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2230,6 +2230,7 @@ def __str__(self):
if len(self.coord_geoms) == 0:
out += " => No coordination in it <=\n"
return out
mp_symbol = ""
for key in self.coord_geoms:
mp_symbol = key
break
Expand Down
3 changes: 1 addition & 2 deletions pymatgen/analysis/chemenv/utils/chemenv_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,11 @@ def setup(self):
break
elif test == "S":
config_file = self.save()
print(f"Configuration has been saved to file {config_file!r}")
break
else:
print(" ... wrong key, try again ...")
print()
if test == "S":
print(f"Configuration has been saved to file {config_file!r}")

@property
def has_materials_project_access(self):
Expand Down
18 changes: 16 additions & 2 deletions pymatgen/analysis/chemenv/utils/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def draw_cg(
show_distorted:
faces_color_override:
"""
csm_suffix = ""
perf_radius = 0
if show_perfect:
if csm_info is None:
raise ValueError("Not possible to show perfect environment without csm_info")
Expand Down Expand Up @@ -219,6 +221,7 @@ def compute_environments(chemenv_configuration):
for key_character, qq in questions.items():
print(f" - <{key_character}> for a structure from {string_sources[qq]['string']}")
test = input(" ... ")
source_type = ""
if test == "q":
break
if test not in list(questions):
Expand All @@ -231,23 +234,30 @@ def compute_environments(chemenv_configuration):
continue
else:
source_type = questions[test]

else:
found = False
source_type = next(iter(questions.values()))

input_source = ""
if found and len(questions) > 1:
input_source = test
input_source = test # type: ignore[reportPossiblyUnboundVariable]

structure = None
if source_type == "cif":
if not found:
input_source = input("Enter path to CIF file : ")
parser = CifParser(input_source)
structure = parser.parse_structures(primitive=True)[0]

elif source_type == "mp":
if not found:
input_source = input('Enter materials project id (e.g. "mp-1902") : ')
from pymatgen.ext.matproj import MPRester

with MPRester() as mpr:
structure = mpr.get_structure_by_material_id(input_source)

lgf.setup_structure(structure)
print(f"Computing environments for {structure.reduced_formula} ... ")
se = lgf.compute_structure_environments(maximum_distance_factor=max_dist_factor)
Expand All @@ -258,7 +268,7 @@ def compute_environments(chemenv_configuration):
'("y" or "n", "d" with details, "g" to see the grid) : '
)
strategy = default_strategy
if test in ["y", "d", "g"]:
if test in {"y", "d", "g"}:
strategy.set_structure_environments(se)
for equiv_list in se.equivalent_sites:
site = equiv_list[0]
Expand All @@ -277,6 +287,7 @@ def compute_environments(chemenv_configuration):
comp = site.species
# ce = strategy.get_site_coordination_environment(site)
reduced_formula = comp.get_reduced_formula_and_factor()[0]
the_cg = None
if strategy.uniquely_determines_coordination_environments:
ce = ces[0]
if ce is None:
Expand Down Expand Up @@ -349,6 +360,9 @@ def compute_environments(chemenv_configuration):
vis = StructureVis(show_polyhedron=False, show_unit_cell=True)
vis.show_help = False
first_time = False
else:
vis = None # TODO: following code logic seems buggy

vis.set_structure(se.structure)
strategy.set_structure_environments(se)
for site in se.structure:
Expand Down
4 changes: 4 additions & 0 deletions pymatgen/analysis/chempot_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ def _get_2d_plot(self, elements: list[Element], label_stable: bool | None, eleme

if element_padding is not None and element_padding > 0:
new_lims = self._get_new_limits_from_padding(domains, elem_indices, element_padding, self.default_min_limit)
else:
new_lims = []

for formula, pts in domains.items():
formula_elems = set(Composition(formula).elements)
Expand Down Expand Up @@ -326,6 +328,8 @@ def _get_3d_plot(

if element_padding and element_padding > 0:
new_lims = self._get_new_limits_from_padding(domains, elem_indices, element_padding, self.default_min_limit)
else:
new_lims = []

for formula, pts in domains.items():
entry = self.entry_dict[formula]
Expand Down
4 changes: 4 additions & 0 deletions pymatgen/analysis/diffraction/tem.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,16 @@ def get_positions(self, structure: Structure, points: list) -> dict[tuple[int, i
points = self.zone_axis_filter(points)
# first is the max_d, min_r
first_point_dict = self.get_first_point(structure, points)
first_point = (0, 0, 0)
first_d = 0.0
for point, v in first_point_dict.items():
first_point = point
first_d = v
spacings = self.get_interplanar_spacings(structure, points)
# second is the first non-parallel-to-first-point vector when sorted.
# note 000 is "parallel" to every plane vector.
second_point = (0, 0, 0)
second_d = 0.0
for plane in sorted(spacings):
second_point, second_d = plane, spacings[plane]
if not self.is_parallel(structure, first_point, second_point):
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/elasticity/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ def get_strain_state_dict(strains, stresses, eq_stress=None, tol: float = 1e-10,
if add_eq:
# add zero strain state
mstrains = np.vstack([mstrains, np.zeros(6)])
mstresses = np.vstack([mstresses, veq_stress])
mstresses = np.vstack([mstresses, veq_stress]) # type: ignore[reportPossiblyUnboundVariable]
# sort strains/stresses by strain values
if sort:
mstresses = mstresses[mstrains[:, ind[0]].argsort()]
Expand Down
Loading

0 comments on commit 1367746

Please sign in to comment.